rslearn 0.0.7__py3-none-any.whl → 0.0.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -2,6 +2,7 @@
 
 import math
 import tempfile
+from contextlib import nullcontext
 from enum import StrEnum
 from typing import Any, cast
 
@@ -63,6 +64,11 @@ pretrained_weights: dict[GalileoSize, str] = {
 
 DEFAULT_NORMALIZER = Normalizer()
 
+AUTOCAST_DTYPE_MAP = {
+    "bfloat16": torch.bfloat16,
+    "float32": torch.float32,
+}
+
 
 class GalileoModel(nn.Module):
     """Galileo backbones."""
@@ -85,6 +91,7 @@ class GalileoModel(nn.Module):
         size: GalileoSize,
         patch_size: int = 4,
         pretrained_path: str | UPath | None = None,
+        autocast_dtype: str | None = "bfloat16",
     ) -> None:
         """Initialize the Galileo model.
 
@@ -93,6 +100,7 @@ class GalileoModel(nn.Module):
             patch_size: The patch size to use.
             pretrained_path: the local path to the pretrained weights. Otherwise it is
                 downloaded and cached in temp directory.
+            autocast_dtype: which dtype to use for autocasting, or set None to disable.
         """
         super().__init__()
         if pretrained_path is None:
@@ -128,8 +136,14 @@ class GalileoModel(nn.Module):
             idx for idx, key in enumerate(SPACE_TIME_BANDS_GROUPS_IDX) if "S1" in key
         ]
 
+        self.size = size
         self.patch_size = patch_size
 
+        if autocast_dtype is not None:
+            self.autocast_dtype = AUTOCAST_DTYPE_MAP[autocast_dtype]
+        else:
+            self.autocast_dtype = None
+
     @staticmethod
     def to_cartesian(
         lat: float | np.ndarray | torch.Tensor, lon: float | np.ndarray | torch.Tensor
@@ -484,18 +498,31 @@ class GalileoModel(nn.Module):
             patch_size = h
         else:
             patch_size = self.patch_size
-        outputs = self.model(
-            s_t_x=galileo_input.s_t_x,
-            s_t_m=galileo_input.s_t_m,
-            sp_x=galileo_input.sp_x,
-            sp_m=galileo_input.sp_m,
-            t_x=galileo_input.t_x,
-            t_m=galileo_input.t_m,
-            st_x=galileo_input.st_x,
-            st_m=galileo_input.st_m,
-            months=galileo_input.months,
-            patch_size=patch_size,
-        )
+
+        # Decide context based on self.autocast_dtype.
+        device = galileo_input.s_t_x.device
+        if self.autocast_dtype is None:
+            context = nullcontext()
+        else:
+            assert device is not None
+            context = torch.amp.autocast(
+                device_type=device.type, dtype=self.autocast_dtype
+            )
+
+        with context:
+            outputs = self.model(
+                s_t_x=galileo_input.s_t_x,
+                s_t_m=galileo_input.s_t_m,
+                sp_x=galileo_input.sp_x,
+                sp_m=galileo_input.sp_m,
+                t_x=galileo_input.t_x,
+                t_m=galileo_input.t_m,
+                st_x=galileo_input.st_x,
+                st_m=galileo_input.st_m,
+                months=galileo_input.months,
+                patch_size=patch_size,
+            )
+
         if h == patch_size:
             # only one spatial patch, so we can just take an average
             # of all the tokens to output b c_g 1 1
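The change above wraps the backbone call in an autocast context only when a dtype is configured, and falls back to a no-op context otherwise. A minimal standalone sketch of the same pattern (the `run_forward` helper and `dtype_name` argument are illustrative, not part of rslearn):

```python
from contextlib import nullcontext

import torch

AUTOCAST_DTYPE_MAP = {"bfloat16": torch.bfloat16, "float32": torch.float32}


def run_forward(
    model: torch.nn.Module, x: torch.Tensor, dtype_name: str | None
) -> torch.Tensor:
    # nullcontext() is a no-op, so the same `with` block covers both cases.
    if dtype_name is None:
        context = nullcontext()
    else:
        context = torch.amp.autocast(
            device_type=x.device.type, dtype=AUTOCAST_DTYPE_MAP[dtype_name]
        )
    with context:
        return model(x)
```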
@@ -515,3 +542,22 @@ class GalileoModel(nn.Module):
                 "b h w c_g d -> b c_g d h w",
             ).mean(dim=1)
         ]
+
+    def get_backbone_channels(self) -> list:
+        """Returns the output channels of this model when used as a backbone.
+
+        The output channels is a list of (patch_size, depth) that corresponds
+        to the feature maps that the backbone returns.
+
+        Returns:
+            the output channels of the backbone as a list of (patch_size, depth) tuples.
+        """
+        if self.size == GalileoSize.BASE:
+            depth = 768
+        elif self.size == GalileoSize.TINY:
+            depth = 192
+        elif self.size == GalileoSize.NANO:
+            depth = 128
+        else:
+            raise ValueError(f"Invalid model size: {self.size}")
+        return [(self.patch_size, depth)]
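`get_backbone_channels` advertises each feature map as a `(patch_size, depth)` tuple, i.e. a map at 1/patch_size of the input resolution with `depth` channels. A hypothetical sketch of how a decoder could consume it (`make_heads` and the 1x1 conv head are illustrative, not part of the package):

```python
import torch


def make_heads(backbone: torch.nn.Module, num_classes: int) -> torch.nn.ModuleList:
    heads = []
    for patch_size, depth in backbone.get_backbone_channels():
        # one 1x1 conv per feature map; `depth` is its channel count
        heads.append(torch.nn.Conv2d(depth, num_classes, kernel_size=1))
    return torch.nn.ModuleList(heads)
```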
@@ -1469,7 +1469,13 @@ class Encoder(GalileoBase):
             # we take the inverse of the mask because a value
             # of True indicates the value *should* take part in
             # attention
-            x = blk(x=x, y=None, attn_mask=~new_m.bool())
+            temp_mask = ~new_m.bool()
+            if temp_mask.all():
+                # if all the tokens are used in attention we can pass a None mask
+                # to the attention block
+                temp_mask = None
+
+            x = blk(x=x, y=None, attn_mask=temp_mask)
 
             if exit_ids_seq is not None:
                 assert exited_tokens is not None
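Passing `attn_mask=None` instead of an all-True mask is more than cosmetic: PyTorch's attention kernels can only dispatch to the fused fast paths when no mask is supplied, even though an all-True boolean mask is numerically a no-op. A standalone sketch (all shapes are arbitrary):

```python
import torch
import torch.nn.functional as F

q = k = v = torch.randn(2, 8, 128, 64)
# All-True boolean mask: every token attends, so it changes nothing numerically,
# but it still forces the general (masked) attention path.
mask = torch.ones(2, 1, 128, 128, dtype=torch.bool)

out_masked = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
out_unmasked = F.scaled_dot_product_attention(q, k, v, attn_mask=None)
# Equal up to kernel-level floating point differences.
assert torch.allclose(out_masked, out_unmasked, atol=1e-5)
```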
@@ -248,3 +248,14 @@ class Presto(nn.Module):
             output_features[batch_idx : batch_idx + self.pixel_batch_size] = output_b
 
         return [rearrange(output_features, "(b h w) d -> b d h w", h=h, w=w, b=b)]
+
+    def get_backbone_channels(self) -> list:
+        """Returns the output channels of this model when used as a backbone.
+
+        The output channels is a list of (patch_size, depth) that corresponds
+        to the feature maps that the backbone returns.
+
+        Returns:
+            the output channels of the backbone as a list of (patch_size, depth) tuples.
+        """
+        return [(1, 128)]
rslearn/models/prithvi.py CHANGED
@@ -1,57 +1,91 @@
 """Prithvi V2."""
 
+import json
 import logging
 import tempfile
 import warnings
+from enum import StrEnum
+from pathlib import Path
 from typing import Any
 
 import numpy as np
 import torch
 import torch.nn as nn
-import yaml
 from einops import rearrange
 from huggingface_hub import hf_hub_download
 from timm.layers import to_2tuple
 from timm.models.vision_transformer import Block
 from torch.nn import functional as F
-from upath import UPath
+
+from rslearn.train.transforms.normalize import Normalize
+from rslearn.train.transforms.transform import Transform
 
 logger = logging.getLogger(__name__)
 
 
-# for Prithvi, true values are ["B02", "B03", "B04", "B05", "B06", "B07"]
-PRITHVI_MEAN = [
-    1087.0,
-    1342.0,
-    1433.0,
-    2734.0,
-    1958.0,
-    1363.0,
-]
-PRITHVI_STD = [
-    2248.0,
-    2179.0,
-    2178.0,
-    1850.0,
-    1242.0,
-    1049.0,
-]
+class PrithviV2Models(StrEnum):
+    """Names for different Prithvi models on torch hub."""
+
+    VIT_300 = "VIT_300"
+    VIT_600 = "VIT_600"
+
+
+MODEL_TO_HF_INFO = {
+    PrithviV2Models.VIT_300: {
+        "hf_hub_id": "ibm-nasa-geospatial/Prithvi-EO-2.0-300M",
+        "weights": "Prithvi_EO_V2_300M.pt",
+        "revision": "b2f2520ab889f42a25c5361ba18761fcb4ea44ad",
+    },
+    PrithviV2Models.VIT_600: {
+        "hf_hub_id": "ibm-nasa-geospatial/Prithvi-EO-2.0-600M",
+        "weights": "Prithvi_EO_V2_600M.pt",
+        "revision": "87f15784813828dc37aa3197a143cd4689e4d080",
+    },
+}
+
+
+HF_HUB_CONFIG_FNAME = "config.json"
+DEFAULT_CACHE_DIR = Path(tempfile.gettempdir(), "rslearn_cache", "prithvi_v2")
 
 
-HF_HUB_ID = "ibm-nasa-geospatial/Prithvi-EO-2.0-300M"
+def get_config(cache_dir: Path, hf_hub_id: str, hf_hub_revision: str) -> dict[str, Any]:
+    """Get the JSON config dict.
+
+    Args:
+        cache_dir: the directory to cache the config.json file, which will be
+            downloaded from HF Hub.
+        hf_hub_id: the HF Hub ID from which to download the config.
+        hf_hub_revision: The revision (commit) to download the config from.
+    """
+    cache_fname = cache_dir / HF_HUB_CONFIG_FNAME
+    if not cache_fname.exists():
+        _ = hf_hub_download(
+            local_dir=cache_dir,
+            repo_id=hf_hub_id,
+            filename=HF_HUB_CONFIG_FNAME,
+            revision=hf_hub_revision,
+        )  # nosec
+    with cache_fname.open() as f:
+        return json.load(f)["pretrained_cfg"]
 
 
 class PrithviV2(nn.Module):
     """An Rslearn wrapper for Prithvi 2.0."""
 
-    input_keys = ["sentinel2"]
+    INPUT_KEY = "image"
 
-    def __init__(self, pretrained_path: str | UPath | None = None, num_frames: int = 1):
-        """Init.
+    def __init__(
+        self,
+        cache_dir: str | Path | None = None,
+        size: PrithviV2Models = PrithviV2Models.VIT_300,
+        num_frames: int = 1,
+    ):
+        """Create a new PrithviV2.
 
-        Inputs:
-            pretrained_path: The folder in which to download the prithvi config
-                and weights. If None, it downloads to a temporary folder.
+        Args:
+            cache_dir: The local folder in which to download the prithvi config and
+                weights. If None, it downloads to a temporary folder.
+            size: the model size, see class for various models.
             num_frames: The number of input frames (timesteps). The model was trained on 3,
                 but if there is just one timestamp examples use 1 (e.g.
                 https://github.com/NASA-IMPACT/Prithvi-EO-2.0/blob/main/examples/
@@ -59,35 +93,28 @@ class PrithviV2(nn.Module):
 
         """
         super().__init__()
-        if pretrained_path is None:
-            pretrained_path = UPath(
-                tempfile.gettempdir(), "rslearn_cache", "prithvi_v2"
-            )
+        if cache_dir is None:
+            cache_dir = DEFAULT_CACHE_DIR
+        cache_dir = Path(cache_dir)
 
-        if not (UPath(pretrained_path) / "config.json").exists():
-            _ = hf_hub_download(
-                local_dir=pretrained_path,
-                repo_id=HF_HUB_ID,
-                filename="config.json",
-                revision="b2f2520ab889f42a25c5361ba18761fcb4ea44ad",
-            )
-        with (UPath(pretrained_path) / "config.json").open("r") as f:
-            config = yaml.safe_load(f)["pretrained_cfg"]
+        hub_id = MODEL_TO_HF_INFO[size]["hf_hub_id"]
+        revision = MODEL_TO_HF_INFO[size]["revision"]
+        checkpoint_fname = MODEL_TO_HF_INFO[size]["weights"]
 
+        config = get_config(cache_dir, hub_id, revision)
         config["num_frames"] = num_frames
-
         self.model = PrithviMAE(**config)
 
-        if not (UPath(pretrained_path) / "Prithvi_EO_V2_300M.pt").exists():
+        if not (cache_dir / checkpoint_fname).exists():
             _ = hf_hub_download(
-                local_dir=pretrained_path,
-                repo_id=HF_HUB_ID,
-                filename="Prithvi_EO_V2_300M.pt",
-                revision="b2f2520ab889f42a25c5361ba18761fcb4ea44ad",
-            )
+                local_dir=cache_dir,
+                repo_id=hub_id,
+                filename=checkpoint_fname,
+                revision=revision,
+            )  # nosec
 
         state_dict = torch.load(
-            UPath(pretrained_path) / "Prithvi_EO_V2_300M.pt",
+            cache_dir / checkpoint_fname,
             map_location="cpu",
             weights_only=True,
         )
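The constructor now resolves the repo id, checkpoint filename, and pinned revision from `MODEL_TO_HF_INFO`, and only downloads when the file is missing from the cache. A minimal sketch of that download-if-missing pattern, reusing the VIT_300 values from the diff (the `fetch_checkpoint` helper is illustrative):

```python
from pathlib import Path

from huggingface_hub import hf_hub_download


def fetch_checkpoint(cache_dir: Path) -> Path:
    fname = "Prithvi_EO_V2_300M.pt"
    if not (cache_dir / fname).exists():
        hf_hub_download(
            local_dir=cache_dir,
            repo_id="ibm-nasa-geospatial/Prithvi-EO-2.0-300M",
            filename=fname,
            # pinning a commit keeps the download reproducible
            revision="b2f2520ab889f42a25c5361ba18761fcb4ea44ad",
        )
    return cache_dir / fname
```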
@@ -125,16 +152,15 @@ class PrithviV2(nn.Module):
     def forward(self, inputs: list[dict[str, Any]]) -> list[torch.Tensor]:
         """Compute feature maps from the Prithvi V2 backbone.
 
-        Inputs:
-            inputs: input dicts that must include "sentinel2"
-                keys depending. Prithvi is designed for HLS (Harmonized Landsat-Sentinel);
-                this naming keeps the model consistent with other rslearn models.
+        Args:
+            inputs: input dicts that must include "image" key containing HLS
+                (Harmonized Landsat-Sentinel) data.
 
         Returns:
             11 feature maps (one per transformer block in the Prithvi model),
             of shape [B, H/p_s, W/p_s, D=1024] where p_s=16 is the patch size.
         """
-        x = torch.stack([inp["sentinel2"] for inp in inputs], dim=0)
+        x = torch.stack([inp[self.INPUT_KEY] for inp in inputs], dim=0)
         x = self._resize_data(x)
         num_timesteps = x.shape[1] // len(self.bands)
         x = rearrange(x, "b (t c) h w -> b c t h w", t=num_timesteps)
@@ -147,6 +173,67 @@ class PrithviV2(nn.Module):
             features, num_timesteps
         )
 
+    def get_backbone_channels(self) -> list:
+        """Returns the output channels of this model when used as a backbone.
+
+        The output channels is a list of (patch_size, depth) that corresponds
+        to the feature maps that the backbone returns.
+
+        Returns:
+            the output channels of the backbone as a list of (patch_size, depth) tuples.
+        """
+        return [(1, 1024)]
+
+
+class PrithviNormalize(Transform):
+    """Normalize inputs using Prithvi normalization.
+
+    Similar to the model, the input should be an image time series under the key
+    "image".
+    """
+
+    def __init__(
+        self,
+        cache_dir: str | Path | None = None,
+        size: PrithviV2Models = PrithviV2Models.VIT_300,
+    ) -> None:
+        """Initialize a new PrithviNormalize.
+
+        Args:
+            cache_dir: the local directory to cache the config.json which contains the
+                means and standard deviations used in the normalization.
+            size: the model size, see class for various models. In this case (and
+                for the current hf revision), the config values (mean and std) are the
+                same for both the 300M and 600M model, so it's safe to not set this.
+        """
+        super().__init__()
+        hub_id = MODEL_TO_HF_INFO[size]["hf_hub_id"]
+        revision = MODEL_TO_HF_INFO[size]["revision"]
+        if cache_dir is None:
+            cache_dir = DEFAULT_CACHE_DIR
+        cache_dir = Path(cache_dir)
+        config = get_config(cache_dir, hub_id, revision)
+        self.normalizer = Normalize(
+            mean=config["mean"],
+            std=config["std"],
+            num_bands=len(config["mean"]),
+            selectors=[PrithviV2.INPUT_KEY],
+        )
+
+    def forward(
+        self, input_dict: dict[str, Any], target_dict: dict[str, Any]
+    ) -> tuple[dict[str, Any], dict[str, Any]]:
+        """Apply Prithvi normalization on the image.
+
+        Args:
+            input_dict: the input, which must contain the "image" key.
+            target_dict: the target.
+
+        Returns:
+            normalized (input_dict, target_dict) tuple
+        """
+        return self.normalizer(input_dict, target_dict)
+
 
 # Copyright (c) IBM Corp. 2024. All rights reserved.
 #
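A hedged usage sketch of the new classes together; the import path follows `rslearn/models/prithvi.py` above, while the tensor shape (6 HLS bands, one timestep) and the empty target dict are illustrative placeholders:

```python
import torch

from rslearn.models.prithvi import PrithviNormalize, PrithviV2, PrithviV2Models

# Instantiating downloads the config and weights to the default cache dir.
model = PrithviV2(size=PrithviV2Models.VIT_300, num_frames=1)
normalize = PrithviNormalize(size=PrithviV2Models.VIT_300)

# Both classes expect the time series under the "image" key.
input_dict = {"image": torch.zeros(6, 224, 224)}
input_dict, target_dict = normalize(input_dict, {})
features = model([input_dict])
```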
@@ -1,5 +1,22 @@
 """Model registry."""
 
-from class_registry import ClassRegistry
+from collections.abc import Callable
+from typing import Any, TypeVar
 
-Models = ClassRegistry()
+_ModelT = TypeVar("_ModelT")
+
+
+class _ModelRegistry(dict[str, type[Any]]):
+    """Registry for Model classes."""
+
+    def register(self, name: str) -> Callable[[type[_ModelT]], type[_ModelT]]:
+        """Decorator to register a model class."""
+
+        def decorator(cls: type[_ModelT]) -> type[_ModelT]:
+            self[name] = cls
+            return cls
+
+        return decorator
+
+
+Models = _ModelRegistry()
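The `register` decorator keeps the old `class_registry` call sites working while the registry itself becomes a plain dict subclass. A usage sketch (`MyModel` is hypothetical; `Models` is the `_ModelRegistry` instance defined above):

```python
@Models.register("my_model")
class MyModel:
    pass

# Lookup is now ordinary dict indexing.
assert Models["my_model"] is MyModel
```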
@@ -0,0 +1,45 @@
+"""The ResizeFeatures module."""
+
+import torch
+
+
+class ResizeFeatures(torch.nn.Module):
+    """Resize input features to new sizes."""
+
+    def __init__(
+        self,
+        out_sizes: list[tuple[int, int]],
+        mode: str = "bilinear",
+    ):
+        """Initialize a ResizeFeatures.
+
+        Args:
+            out_sizes: the output sizes of the feature maps. There must be one entry
+                for each input feature map.
+            mode: mode to pass to torch.nn.Upsample, e.g. "bilinear" (default) or
+                "nearest".
+        """
+        super().__init__()
+        layers = []
+        for size in out_sizes:
+            layers.append(
+                torch.nn.Upsample(
+                    size=size,
+                    mode=mode,
+                )
+            )
+        self.layers = torch.nn.ModuleList(layers)
+
+    def forward(
+        self, features: list[torch.Tensor], inputs: list[torch.Tensor]
+    ) -> list[torch.Tensor]:
+        """Resize the input feature maps to new sizes.
+
+        Args:
+            features: list of feature maps at different resolutions.
+            inputs: original inputs (ignored).
+
+        Returns:
+            resized feature maps
+        """
+        return [self.layers[idx](feat_map) for idx, feat_map in enumerate(features)]
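An illustrative use of the new module, resizing two dummy feature maps to a common 64x64 grid (all shapes here are made up):

```python
import torch

resize = ResizeFeatures(out_sizes=[(64, 64), (64, 64)], mode="bilinear")
features = [torch.randn(1, 256, 32, 32), torch.randn(1, 512, 16, 16)]
resized = resize(features, inputs=[])
assert all(f.shape[-2:] == (64, 64) for f in resized)
```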
@@ -20,12 +20,13 @@ class SimpleTimeSeries(torch.nn.Module):
     def __init__(
         self,
         encoder: torch.nn.Module,
-        image_channels: int,
+        image_channels: int | None = None,
        op: str = "max",
         groups: list[list[int]] | None = None,
         num_layers: int | None = None,
         image_key: str = "image",
         backbone_channels: list[tuple[int, int]] | None = None,
+        image_keys: dict[str, int] | None = None,
     ) -> None:
         """Create a new SimpleTimeSeries.
 
@@ -48,13 +49,25 @@ class SimpleTimeSeries(torch.nn.Module):
             image_key: the key to access the images.
             backbone_channels: manually specify the backbone channels. Can be set if
                 the encoder does not provide get_backbone_channels function.
+            image_keys: as an alternative to setting image_channels, map from the key
+                in input dict to the number of channels per timestep for that modality.
+                This way SimpleTimeSeries can be used with multimodal inputs. One of
+                image_channels or image_keys must be specified.
         """
+        if (image_channels is None and image_keys is None) or (
+            image_channels is not None and image_keys is not None
+        ):
+            raise ValueError(
+                "exactly one of image_channels and image_keys must be specified"
+            )
+
         super().__init__()
         self.encoder = encoder
         self.image_channels = image_channels
         self.op = op
         self.groups = groups
         self.image_key = image_key
+        self.image_keys = image_keys
 
         if backbone_channels is not None:
             out_channels = backbone_channels
@@ -144,6 +157,26 @@ class SimpleTimeSeries(torch.nn.Module):
             out_channels.append((downsample_factor, depth * self.num_groups))
         return out_channels
 
+    def _get_batched_images(
+        self, input_dicts: list[dict[str, Any]], image_key: str, image_channels: int
+    ) -> torch.Tensor:
+        """Collect and reshape images across input dicts.
+
+        The BTCHW image time series are reshaped to (B*T)CHW so they can be passed to
+        the forward pass of a per-image (unitemporal) model.
+        """
+        images = torch.stack(
+            [input_dict[image_key] for input_dict in input_dicts], dim=0
+        )
+        n_batch = images.shape[0]
+        n_images = images.shape[1] // image_channels
+        n_height = images.shape[2]
+        n_width = images.shape[3]
+        batched_images = images.reshape(
+            n_batch * n_images, image_channels, n_height, n_width
+        )
+        return batched_images
+
     def forward(
         self,
         inputs: list[dict[str, Any]],
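A worked shape example for the reshape in `_get_batched_images`, with made-up dimensions: each input dict holds a (T*C, H, W) stack, batching gives (B, T*C, H, W), and the reshape splits time back out of the channel axis:

```python
import torch

B, T, C, H, W = 2, 3, 6, 32, 32
images = torch.randn(B, T * C, H, W)  # stacked time series per sample
batched = images.reshape(B * T, C, H, W)  # one row per (sample, timestep)
assert batched.shape == (B * T, C, H, W)
```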
@@ -156,15 +189,37 @@ class SimpleTimeSeries(torch.nn.Module):
         """
         # First get features of each image.
         # To do so, we need to split up each grouped image into its component images (which have had their channels stacked).
-        images = torch.stack([inp[self.image_key] for inp in inputs], dim=0)
-        n_batch = images.shape[0]
-        n_images = images.shape[1] // self.image_channels
-        n_height = images.shape[2]
-        n_width = images.shape[3]
-        batched_images = images.reshape(
-            n_batch * n_images, self.image_channels, n_height, n_width
-        )
-        batched_inputs = [{self.image_key: image} for image in batched_images]
+        batched_inputs: list[dict[str, Any]] | None = None
+        n_batch = len(inputs)
+        n_images: int | None = None
+
+        if self.image_keys is not None:
+            for image_key, image_channels in self.image_keys.items():
+                batched_images = self._get_batched_images(
+                    inputs, image_key, image_channels
+                )
+
+                if batched_inputs is None:
+                    batched_inputs = [{} for _ in batched_images]
+                    n_images = batched_images.shape[0] // n_batch
+                elif n_images != batched_images.shape[0] // n_batch:
+                    raise ValueError(
+                        "expected all modalities to have the same number of timesteps"
+                    )
+
+                for i, image in enumerate(batched_images):
+                    batched_inputs[i][image_key] = image
+
+        else:
+            assert self.image_channels is not None
+            batched_images = self._get_batched_images(
+                inputs, self.image_key, self.image_channels
+            )
+            batched_inputs = [{self.image_key: image} for image in batched_images]
+            n_images = batched_images.shape[0] // n_batch
+
+        assert n_images is not None
+
         all_features = [
             feat_map.reshape(
                 n_batch,
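A hypothetical configuration of the new multimodal path; the modality names, channel counts, and `encoder` are illustrative, not values from the package:

```python
model = SimpleTimeSeries(
    encoder=encoder,
    # channels per timestep for each modality; timestep counts must match
    image_keys={"sentinel2": 9, "sentinel1": 2},
)

# Passing neither (or both) of image_channels / image_keys raises ValueError.
```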
@@ -23,13 +23,13 @@ class Upsample(torch.nn.Module):
     def forward(
         self, features: list[torch.Tensor], inputs: list[torch.Tensor]
     ) -> list[torch.Tensor]:
-        """Compute flat output vector from multi-scale feature map.
+        """Upsample each feature map.
 
         Args:
             features: list of feature maps at different resolutions.
             inputs: original inputs (ignored).
 
         Returns:
-            flat feature vector
+            upsampled feature maps
         """
         return [self.layer(feat_map) for feat_map in features]