rslearn 0.0.7__py3-none-any.whl → 0.0.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,10 +1,10 @@
1
1
  """Classes to implement dataset materialization."""
2
2
 
3
+ from collections.abc import Callable
3
4
  from typing import Any, Generic, TypeVar
4
5
 
5
6
  import numpy as np
6
7
  import numpy.typing as npt
7
- from class_registry import ClassRegistry
8
8
  from rasterio.enums import Resampling
9
9
 
10
10
  from rslearn.config import (
@@ -25,7 +25,26 @@ from rslearn.utils.vector_format import load_vector_format
25
25
  from .remap import Remapper, load_remapper
26
26
  from .window import Window
27
27
 
28
- Materializers = ClassRegistry()
28
+ _MaterializerT = TypeVar("_MaterializerT", bound="Materializer")
29
+
30
+
31
+ class _MaterializerRegistry(dict[str, type["Materializer"]]):
32
+ """Registry for Materializer classes."""
33
+
34
+ def register(
35
+ self, name: str
36
+ ) -> Callable[[type[_MaterializerT]], type[_MaterializerT]]:
37
+ """Decorator to register a materializer class."""
38
+
39
+ def decorator(cls: type[_MaterializerT]) -> type[_MaterializerT]:
40
+ self[name] = cls
41
+ return cls
42
+
43
+ return decorator
44
+
45
+
46
+ Materializers = _MaterializerRegistry()
47
+
29
48
 
30
49
  LayerConfigType = TypeVar("LayerConfigType", bound=LayerConfig)
31
50
 
rslearn/dataset/remap.py CHANGED
@@ -1,18 +1,42 @@
1
1
  """Classes to remap raster values."""
2
2
 
3
- from typing import Any
3
+ from collections.abc import Callable
4
+ from typing import Any, TypeVar
4
5
 
5
6
  import numpy as np
6
7
  import numpy.typing as npt
7
- from class_registry import ClassRegistry
8
8
 
9
- Remappers = ClassRegistry()
9
+ _RemapperT = TypeVar("_RemapperT", bound="Remapper")
10
+
11
+
12
+ class _RemapperRegistry(dict[str, type["Remapper"]]):
13
+ """Registry for Remapper classes."""
14
+
15
+ def register(self, name: str) -> Callable[[type[_RemapperT]], type[_RemapperT]]:
16
+ """Decorator to register a remapper class."""
17
+
18
+ def decorator(cls: type[_RemapperT]) -> type[_RemapperT]:
19
+ self[name] = cls
20
+ return cls
21
+
22
+ return decorator
23
+
24
+
25
+ Remappers = _RemapperRegistry()
10
26
  """Registry of Remapper implementations."""
11
27
 
12
28
 
13
29
  class Remapper:
14
30
  """An abstract class that remaps pixel values based on layer configuration."""
15
31
 
32
+ def __init__(self, config: dict[str, Any]) -> None:
33
+ """Initialize a Remapper.
34
+
35
+ Args:
36
+ config: the config dict for this remapper.
37
+ """
38
+ pass
39
+
16
40
  def __call__(
17
41
  self, array: npt.NDArray[Any], dtype: npt.DTypeLike
18
42
  ) -> npt.NDArray[Any]:
@@ -67,4 +91,5 @@ class LinearRemapper(Remapper):
67
91
 
68
92
def load_remapper(config: dict[str, Any]) -> Remapper:
    """Load a remapper from a configuration dictionary.

    Args:
        config: dict with a "name" key identifying a registered Remapper
            class; the whole dict is passed to that class's constructor.

    Returns:
        the instantiated Remapper.
    """
    remapper_cls = Remappers[config["name"]]
    return remapper_cls(config)
rslearn/main.py CHANGED
@@ -4,6 +4,7 @@ import argparse
4
4
  import multiprocessing
5
5
  import random
6
6
  import sys
7
+ import time
7
8
  from collections.abc import Callable
8
9
  from datetime import UTC, datetime, timedelta
9
10
  from typing import Any, TypeVar
@@ -19,8 +20,18 @@ from rslearn.const import WGS84_EPSG
19
20
  from rslearn.data_sources import Item, data_source_from_config
20
21
  from rslearn.dataset import Dataset, Window, WindowLayerData
21
22
  from rslearn.dataset.add_windows import add_windows_from_box, add_windows_from_file
23
+ from rslearn.dataset.handler_summaries import (
24
+ ErrorOutcome,
25
+ IngestCounts,
26
+ IngestDatasetJobsSummary,
27
+ LayerIngestSummary,
28
+ MaterializeDatasetWindowsSummary,
29
+ PrepareDatasetWindowsSummary,
30
+ UnknownIngestCounts,
31
+ )
22
32
  from rslearn.dataset.index import DatasetIndex
23
33
  from rslearn.dataset.manage import (
34
+ AttemptsCounter,
24
35
  materialize_dataset_windows,
25
36
  prepare_dataset_windows,
26
37
  retry,
@@ -287,7 +298,7 @@ def add_apply_on_windows_args(parser: argparse.ArgumentParser) -> None:
287
298
 
288
299
 
289
300
  def apply_on_windows(
290
- f: Callable[[list[Window]], None],
301
+ f: Callable[[list[Window]], Any],
291
302
  dataset: Dataset,
292
303
  group: str | list[str] | None = None,
293
304
  names: list[str] | None = None,
@@ -367,7 +378,7 @@ def apply_on_windows(
367
378
  p.close()
368
379
 
369
380
 
370
- def apply_on_windows_args(f: Callable[..., None], args: argparse.Namespace) -> None:
381
+ def apply_on_windows_args(f: Callable[..., Any], args: argparse.Namespace) -> None:
371
382
  """Call apply_on_windows with arguments passed via command-line interface."""
372
383
  dataset = Dataset(UPath(args.root), args.disabled_layers)
373
384
  apply_on_windows(
@@ -413,12 +424,12 @@ class PrepareHandler:
413
424
  """
414
425
  self.dataset = dataset
415
426
 
416
- def __call__(self, windows: list[Window]) -> None:
427
+ def __call__(self, windows: list[Window]) -> PrepareDatasetWindowsSummary:
417
428
  """Prepares the windows from apply_on_windows."""
418
429
  logger.info(f"Running prepare on {len(windows)} windows")
419
430
  if self.dataset is None:
420
431
  raise ValueError("dataset not set")
421
- prepare_dataset_windows(
432
+ return prepare_dataset_windows(
422
433
  self.dataset,
423
434
  windows,
424
435
  self.force,
@@ -502,14 +513,20 @@ class IngestHandler:
502
513
 
503
514
  def __call__(
504
515
  self, jobs: list[tuple[str, LayerConfig, Item, list[STGeometry]]]
505
- ) -> None:
516
+ ) -> IngestDatasetJobsSummary:
506
517
  """Ingest the specified items.
507
518
 
508
519
  The items are computed from list of windows via IngestHandler.get_jobs.
509
520
 
510
521
  Args:
511
- jobs: list of (layer_name, item, geometries) tuples to ingest.
522
+ jobs: list of (layer_name, layer_cfg, item, geometries) tuples to ingest.
523
+
524
+ Returns:
525
+ summary of the ingest jobs operation fit for telemetry purposes.
512
526
  """
527
+ start_time = time.monotonic()
528
+ layer_summaries: list[LayerIngestSummary] = []
529
+
513
530
  logger.info(f"Running ingest for {len(jobs)} jobs")
514
531
  import gc
515
532
 
@@ -533,6 +550,8 @@ class IngestHandler:
533
550
  layer_cfg = self.dataset.layers[layer_name]
534
551
  data_source = data_source_from_config(layer_cfg, self.dataset.path)
535
552
 
553
+ attempts_counter = AttemptsCounter()
554
+ ingest_counts: IngestCounts | UnknownIngestCounts
536
555
  try:
537
556
  retry(
538
557
  lambda: data_source.ingest(
@@ -544,18 +563,47 @@ class IngestHandler:
544
563
  ),
545
564
  retry_max_attempts=self.retry_max_attempts,
546
565
  retry_backoff=self.retry_backoff,
566
+ attempts_counter=attempts_counter,
567
+ )
568
+ ingest_counts = IngestCounts(
569
+ items_ingested=len(items_and_geometries),
570
+ geometries_ingested=sum(
571
+ len(geometries) for _, geometries in items_and_geometries
572
+ ),
547
573
  )
548
574
  except Exception as e:
549
575
  if not self.ignore_errors:
550
576
  raise
551
577
 
578
+ ingest_counts = UnknownIngestCounts(
579
+ items_attempted=len(items_and_geometries),
580
+ geometries_attempted=sum(
581
+ len(geometries) for _, geometries in items_and_geometries
582
+ ),
583
+ )
552
584
  logger.error(
553
585
  "warning: got error while ingesting "
554
586
  + f"{len(items_and_geometries)} items: {e}"
555
587
  )
556
588
 
589
+ layer_summaries.append(
590
+ LayerIngestSummary(
591
+ layer_name=layer_name,
592
+ data_source_name=getattr(layer_cfg.data_source, "name", "N/A"),
593
+ duration_seconds=time.monotonic() - start_time,
594
+ ingest_counts=ingest_counts,
595
+ ingest_attempts=attempts_counter.value,
596
+ )
597
+ )
598
+
557
599
  gc.collect()
558
600
 
601
+ return IngestDatasetJobsSummary(
602
+ duration_seconds=time.monotonic() - start_time,
603
+ num_jobs=len(jobs),
604
+ layer_summaries=layer_summaries,
605
+ )
606
+
559
607
  def _load_layer_data_for_windows(
560
608
  self, windows: list[Window], workers: int
561
609
  ) -> list[tuple[Window, dict[str, WindowLayerData]]]:
@@ -686,13 +734,16 @@ class MaterializeHandler:
686
734
  """
687
735
  self.dataset = dataset
688
736
 
689
- def __call__(self, windows: list[Window]) -> None:
737
+ def __call__(
738
+ self, windows: list[Window]
739
+ ) -> MaterializeDatasetWindowsSummary | ErrorOutcome:
690
740
  """Materializes the windows from apply_on_windows."""
691
741
  logger.info(f"Running Materialize with {len(windows)} windows")
742
+ start_time = time.monotonic()
692
743
  if self.dataset is None:
693
744
  raise ValueError("dataset not set")
694
745
  try:
695
- materialize_dataset_windows(
746
+ return materialize_dataset_windows(
696
747
  self.dataset,
697
748
  windows,
698
749
  retry_max_attempts=self.retry_max_attempts,
@@ -703,6 +754,7 @@ class MaterializeHandler:
703
754
  logger.error(f"Error materializing windows: {e}")
704
755
  raise
705
756
  logger.warning(f"Ignoring error while materializing windows: {e}")
757
+ return ErrorOutcome(duration_seconds=time.monotonic() - start_time)
706
758
 
707
759
 
708
760
  @register_handler("dataset", "materialize")
@@ -15,6 +15,7 @@ from huggingface_hub import hf_hub_download
15
15
  # from claymodel.module import ClayMAEModule
16
16
  from terratorch.models.backbones.clay_v15.module import ClayMAEModule
17
17
 
18
+ from rslearn.train.transforms.normalize import Normalize
18
19
  from rslearn.train.transforms.transform import Transform
19
20
 
20
21
 
@@ -163,13 +164,36 @@ class Clay(torch.nn.Module):
163
164
 
164
165
 
165
166
  class ClayNormalize(Transform):
166
- """Normalize inputs using Clay metadata."""
167
+ """Normalize inputs using Clay metadata.
168
+
169
+ For Sentinel-1, the intensities should be converted to decibels.
170
+ """
167
171
 
168
172
  def __init__(self, metadata_path: str = CLAY_METADATA_PATH) -> None:
169
173
  """Initialize ClayNormalize."""
170
174
  super().__init__()
171
175
  with open(metadata_path) as f:
172
- self.metadata = yaml.safe_load(f)
176
+ metadata = yaml.safe_load(f)
177
+ normalizers = {}
178
+ for modality in CLAY_MODALITIES:
179
+ if modality not in metadata:
180
+ continue
181
+ modality_metadata = metadata[modality]
182
+ means = [
183
+ modality_metadata["bands"]["mean"][b]
184
+ for b in modality_metadata["band_order"]
185
+ ]
186
+ stds = [
187
+ modality_metadata["bands"]["std"][b]
188
+ for b in modality_metadata["band_order"]
189
+ ]
190
+ normalizers[modality] = Normalize(
191
+ mean=means,
192
+ std=stds,
193
+ selectors=[modality],
194
+ num_bands=len(means),
195
+ )
196
+ self.normalizers = torch.nn.ModuleDict(normalizers)
173
197
 
174
198
  def apply_image(
175
199
  self, image: torch.Tensor, means: list[float], stds: list[float]
@@ -188,17 +212,8 @@ class ClayNormalize(Transform):
188
212
  self, input_dict: dict[str, Any], target_dict: dict[str, Any]
189
213
  ) -> tuple[dict[str, Any], dict[str, Any]]:
190
214
  """Normalize the specified image with Clay normalization."""
191
- for modality in CLAY_MODALITIES:
192
- if modality not in input_dict or modality not in self.metadata:
215
+ for modality, normalizer in self.normalizers.items():
216
+ if modality not in input_dict:
193
217
  continue
194
- modality_metadata = self.metadata[modality]
195
- means = [
196
- modality_metadata["bands"]["mean"][b]
197
- for b in modality_metadata["band_order"]
198
- ]
199
- stds = [
200
- modality_metadata["bands"]["std"][b]
201
- for b in modality_metadata["band_order"]
202
- ]
203
- input_dict[modality] = self.apply_image(input_dict[modality], means, stds)
218
+ input_dict, target_dict = normalizer(input_dict, target_dict)
204
219
  return input_dict, target_dict
@@ -3,11 +3,12 @@
3
3
  import logging
4
4
  import math
5
5
  from enum import Enum
6
+ from pathlib import Path
6
7
 
7
8
  import torch
8
9
  import torch.nn.functional as F
9
10
  from einops import rearrange
10
- from upath import UPath
11
+ from huggingface_hub import hf_hub_download
11
12
 
12
13
  from .copernicusfm_src.model_vit import vit_base_patch16
13
14
 
@@ -64,6 +65,10 @@ MODALITY_TO_WAVELENGTH_BANDWIDTHS: dict[str, dict[str, list]] = {
64
65
  },
65
66
  }
66
67
 
68
# Location of the pretrained CopernicusFM checkpoint on Hugging Face Hub.
HF_REPO_ID = "wangyi111/Copernicus-FM"
# Pinned to an exact commit so weight downloads are reproducible.
HF_REPO_REVISION = "e1db406d517a122c8373802e1c130c5fc4789f84"
HF_FILENAME = "CopernicusFM_ViT_base_varlang_e100.pth"
71
+
67
72
 
68
73
  class CopernicusFM(torch.nn.Module):
69
74
  """Wrapper for Copernicus FM to ingest Masked Helios Sample."""
@@ -80,44 +85,51 @@ class CopernicusFM(torch.nn.Module):
80
85
  def __init__(
81
86
  self,
82
87
  band_order: dict[str, list[str]],
83
- load_directory: str | None,
88
+ cache_dir: str | Path | None = None,
84
89
  ) -> None:
85
90
  """Initialize the Copernicus FM wrapper.
86
91
 
87
92
  Args:
88
- band_order: The band order for each modality
89
- load_directory: The directory to load from, if None no weights are loaded
93
+ band_order: The band order for each modality that will be used. The bands
94
+ can be provided in any order, and any subset can be used.
95
+ cache_dir: The directory to cache the weights. If None, a default directory
96
+ managed by huggingface_hub is used. The weights are downloaded from
97
+ Hugging Face (https://huggingface.co/wangyi111/Copernicus-FM).
90
98
  """
91
99
  super().__init__()
92
100
 
101
+ # Make sure all keys in band_order are in supported_modalities.
102
+ for modality_name in band_order.keys():
103
+ if modality_name in self.supported_modalities:
104
+ continue
105
+ raise ValueError(
106
+ f"band_order contains unsupported modality {modality_name}"
107
+ )
108
+
93
109
  # global_pool=True so that we initialize the fc_norm layer
94
- self.band_order = band_order
95
110
  self.model = vit_base_patch16(num_classes=10, global_pool=True)
96
- if load_directory is not None:
97
- check_point = torch.load(
98
- UPath(load_directory) / "CopernicusFM_ViT_base_varlang_e100.pth",
99
- weights_only=True,
100
- )
101
- if "model" in check_point:
102
- state_dict = check_point["model"]
103
- else:
104
- state_dict = check_point
105
- self.model.load_state_dict(state_dict, strict=False)
106
-
107
- # take MODALITY_TO_WAVELENGTH_BANDWIDTHS and rearrage it so that it has the same
108
- # ordering as the Helios band orders, defined by Modality.band_order
111
+
112
+ # Load weights, downloading if needed.
113
+ local_fname = hf_hub_download(
114
+ repo_id=HF_REPO_ID,
115
+ revision=HF_REPO_REVISION,
116
+ filename=HF_FILENAME,
117
+ local_dir=cache_dir,
118
+ ) # nosec
119
+ state_dict = torch.load(local_fname, weights_only=True)
120
+ self.model.load_state_dict(state_dict, strict=False)
121
+
122
+ # take MODALITY_TO_WAVELENGTH_BANDWIDTHS and rearrange it so that it has the same
123
+ # ordering as the user-provided band order.
109
124
  self.modality_to_wavelength_bandwidths = {}
110
125
  for modality in self.supported_modalities:
126
+ if modality not in band_order:
127
+ continue
128
+
111
129
  wavelength_bandwidths = MODALITY_TO_WAVELENGTH_BANDWIDTHS[modality]
112
130
  wavelengths = []
113
131
  bandwidths = []
114
- modality_band_order = self.band_order.get(modality, None)
115
- if modality_band_order is None:
116
- logger.warning(
117
- f"Band order for modality {modality} not found in band_order dictionary, unable to use this modality unless specified"
118
- )
119
- continue
120
- for b in modality_band_order:
132
+ for b in band_order[modality]:
121
133
  cfm_idx = wavelength_bandwidths["band_names"].index(b)
122
134
  wavelengths.append(wavelength_bandwidths["band_wavelengths"][cfm_idx])
123
135
  bandwidths.append(wavelength_bandwidths["band_bandwidths"][cfm_idx])
@@ -0,0 +1,166 @@
1
+ """DinoV3 model."""
2
+
3
+ from enum import StrEnum
4
+ from pathlib import Path
5
+ from typing import Any
6
+
7
+ import torch
8
+ import torchvision
9
+ from einops import rearrange
10
+
11
+ from rslearn.train.transforms.normalize import Normalize
12
+ from rslearn.train.transforms.transform import Transform
13
+
14
+
15
+ class DinoV3Models(StrEnum):
16
+ """Names for different DinoV3 images on torch hub."""
17
+
18
+ SMALL_WEB = "dinov3_vits16"
19
+ SMALL_PLUS_WEB = "dinov3_vits16plus"
20
+ BASE_WEB = "dinov3_vitb16"
21
+ LARGE_WEB = "dinov3_vitl16"
22
+ HUGE_PLUS_WEB = "dinov3_vith16plus"
23
+ FULL_7B_WEB = "dinov3_vit7b16"
24
+ LARGE_SATELLITE = "dinov3_vitl16_sat"
25
+ FULL_7B_SATELLITE = "dinov3_vit7b16_sat"
26
+
27
+
28
# Checkpoint filename for each model variant; the small variants have no
# entry here, so only the keys below can load local weights.
DINOV3_PTHS: dict[str, str] = {
    DinoV3Models.LARGE_SATELLITE: "dinov3_vitl16_pretrain_sat493m-eadcf0ff.pth",
    DinoV3Models.FULL_7B_SATELLITE: "dinov3_vit7b16_pretrain_sat493m-a6675841.pth",
    DinoV3Models.BASE_WEB: "dinov3_vitb16_pretrain_lvd1689m-73cec8be.pth",
    DinoV3Models.LARGE_WEB: "dinov3_vitl16_pretrain_lvd1689m-8aa4cbdd.pth",
    DinoV3Models.HUGE_PLUS_WEB: "dinov3_vith16plus_pretrain_lvd1689m-7c1da9a5.pth",
    DinoV3Models.FULL_7B_WEB: "dinov3_vit7b16_pretrain_lvd1689m-a955f4.pth",
}
36
+
37
+
38
class DinoV3(torch.nn.Module):
    """DinoV3 backbone wrapper.

    Pretrained weights must already be downloaded into ``checkpoint_dir`` for
    them to be loaded; see
    https://github.com/facebookresearch/dinov3?tab=readme-ov-file#pretrained-models

    Only takes RGB as input, and expects normalized data (use DinoV3Normalize).

    Uses patch size 16. The input is resized to 256x256; when applying DinoV3
    on segmentation or detection tasks with inputs larger than 256x256, it may
    be best to train and predict on 256x256 crops (using the
    SplitConfig.patch_size argument).
    """

    # Fixed geometry shared by the supported variants.
    image_size: int = 256
    patch_size: int = 16
    output_dim: int = 1024

    def __init__(
        self,
        checkpoint_dir: str | None,
        size: str = DinoV3Models.LARGE_SATELLITE,
        use_cls_token: bool = False,
        do_resizing: bool = True,
    ) -> None:
        """Instantiate a new DinoV3 instance.

        Args:
            checkpoint_dir: the local path to the pretrained weight dir. If
                None, we load the architecture only (randomly initialized).
            size: the model size; see DinoV3Models for the available variants.
            use_cls_token: use pooled class token (for classification),
                otherwise returns spatial feature map.
            do_resizing: whether to resize inputs to 256x256. Default true.
        """
        super().__init__()
        self.checkpoint_dir = checkpoint_dir
        self.size = size
        self.use_cls_token = use_cls_token
        self.do_resizing = do_resizing
        self.model = self._load_model(size, checkpoint_dir)

    def _load_model(self, size: str, checkpoint_dir: str | None) -> torch.nn.Module:
        """Build the hub model for ``size``, loading local weights if provided."""
        # Hub entrypoint names never carry the "_sat" suffix; only the weight
        # files differ between the web and satellite variants.
        hub_name = size.replace("_sat", "")
        if checkpoint_dir is None:
            # No weights on disk: return the randomly initialized architecture.
            return torch.hub.load(
                "facebookresearch/dinov3", hub_name, pretrained=False
            )  # nosec
        weight_path = str(Path(checkpoint_dir) / DINOV3_PTHS[size])
        return torch.hub.load(
            "facebookresearch/dinov3",
            hub_name,
            weights=weight_path,
        )  # nosec

    def forward(self, inputs: list[dict[str, Any]]) -> list[torch.Tensor]:
        """Forward pass for the dinov3 model.

        Args:
            inputs: input dicts that must include "image" key.

        Returns:
            single-scale feature tensors from the encoder.
        """
        images = torch.stack([inp["image"] for inp in inputs], dim=0)  # (B, C, H, W)

        if self.do_resizing and tuple(images.shape[2:]) != (
            self.image_size,
            self.image_size,
        ):
            images = torchvision.transforms.functional.resize(
                images, [self.image_size, self.image_size]
            )

        if self.use_cls_token:
            # Pooled class-token embedding, e.g. for classification heads.
            return [self.model(images)]

        tokens = self.model.forward_features(images)["x_norm_patchtokens"]
        _, num_patches, _ = tokens.shape
        # Patch tokens come back flattened; assumes a square patch grid.
        side = int(num_patches**0.5)
        return [rearrange(tokens, "b (h w) d -> b d h w", h=side, w=side)]

    def get_backbone_channels(self) -> list:
        """Returns the output channels of this model when used as a backbone.

        The output channels is a list of (downsample_factor, depth) that
        corresponds to the feature maps that the backbone returns. For example,
        an element [2, 32] indicates that the corresponding feature map is 1/2
        the input resolution and has 32 channels.
        """
        return [(self.patch_size, self.output_dim)]
127
+
128
+
129
class DinoV3Normalize(Transform):
    """Normalize inputs using DinoV3 normalization.

    Normalizes the "image" key in the input with the statistics used in Dino
    pretraining. The satellite-pretrained models use slightly different
    statistics than the base web-image models, so set ``satellite`` to match
    the pretrained model you are using.

    Input "image" should be an RGB-like image with values in 0-255.
    """

    def __init__(self, satellite: bool = True):
        """Initialize a new DinoV3Normalize."""
        super().__init__()
        self.satellite = satellite
        # Per-channel RGB statistics from pretraining, on a 0-1 scale.
        if satellite:
            mean = [0.430, 0.411, 0.296]
            std = [0.213, 0.156, 0.143]
        else:
            mean = [0.485, 0.456, 0.406]
            std = [0.229, 0.224, 0.225]

        # Inputs are expected in 0-255, so scale the statistics to match.
        self.normalize = Normalize(
            [value * 255 for value in mean],
            [value * 255 for value in std],
            num_bands=3,
        )

    def forward(
        self, input_dict: dict[str, Any], target_dict: dict[str, Any]
    ) -> tuple[dict[str, Any], dict[str, Any]]:
        """Normalize the specified image with DinoV3 normalization.

        Args:
            input_dict: the input dictionary.
            target_dict: the target dictionary.

        Returns:
            normalized (input_dicts, target_dicts) tuple
        """
        return self.normalize(input_dict, target_dict)