rslearn 0.0.3__py3-none-any.whl → 0.0.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn/arg_parser.py +59 -0
- rslearn/data_sources/copernicus.py +4 -4
- rslearn/data_sources/earthdaily.py +21 -1
- rslearn/data_sources/gcp_public_data.py +3 -3
- rslearn/data_sources/utils.py +1 -17
- rslearn/main.py +10 -1
- rslearn/models/trunk.py +0 -144
- rslearn/train/callbacks/adapters.py +53 -0
- rslearn/train/callbacks/freeze_unfreeze.py +319 -0
- rslearn/train/callbacks/gradients.py +54 -34
- rslearn/train/data_module.py +70 -41
- rslearn/train/dataset.py +232 -54
- rslearn/train/lightning_module.py +4 -0
- rslearn/train/prediction_writer.py +7 -0
- rslearn/train/scheduler.py +15 -0
- rslearn/train/tasks/per_pixel_regression.py +259 -0
- rslearn/train/tasks/regression.py +6 -4
- rslearn/train/tasks/segmentation.py +44 -14
- rslearn/train/transforms/mask.py +69 -0
- rslearn/utils/geometry.py +8 -8
- {rslearn-0.0.3.dist-info → rslearn-0.0.4.dist-info}/METADATA +3 -3
- {rslearn-0.0.3.dist-info → rslearn-0.0.4.dist-info}/RECORD +26 -24
- rslearn/models/moe/distributed.py +0 -262
- rslearn/models/moe/soft.py +0 -676
- {rslearn-0.0.3.dist-info → rslearn-0.0.4.dist-info}/WHEEL +0 -0
- {rslearn-0.0.3.dist-info → rslearn-0.0.4.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.3.dist-info → rslearn-0.0.4.dist-info}/licenses/LICENSE +0 -0
- {rslearn-0.0.3.dist-info → rslearn-0.0.4.dist-info}/top_level.txt +0 -0
rslearn/arg_parser.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
1
|
+
"""Custom Lightning ArgumentParser with environment variable substitution support."""
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
import re
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
from jsonargparse import Namespace
|
|
8
|
+
from lightning.pytorch.cli import LightningArgumentParser
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def substitute_env_vars_in_string(content: str) -> str:
    """Substitute environment variables in a string.

    Every ``${VAR_NAME}`` occurrence is replaced with the value of the
    environment variable ``VAR_NAME``; variables that are not set are replaced
    with the empty string. The substitution runs on the raw text before any
    YAML parsing, so it applies everywhere in the content (comments included),
    and values are inserted verbatim (not YAML-quoted) -- a value containing
    YAML syntax can therefore change how the document parses. Bare
    ``$VAR_NAME`` references and an empty ``${}`` are left unchanged.

    Args:
        content: The string content containing template variables

    Returns:
        The string with environment variables substituted
    """
    pattern = r"\$\{([^}]+)\}"

    def replace_variable(match_obj: re.Match[str]) -> str:
        # os.getenv with a str default always returns a str, so no separate
        # None check is needed.
        return os.getenv(match_obj.group(1), "")

    return re.sub(pattern, replace_variable, content)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class RslearnArgumentParser(LightningArgumentParser):
    """Custom ArgumentParser that substitutes environment variables in config files.

    This parser extends LightningArgumentParser and overrides ``parse_string``
    so that every ``${VAR_NAME}`` pattern in the configuration text is replaced
    with the corresponding environment variable value (empty string if unset)
    before the parent parser sees it. This allows config files to use
    environment variables while still passing Lightning's validation.

    NOTE(review): substitution only happens for content that reaches
    ``parse_string``; presumably jsonargparse routes config files loaded by path
    through it as well -- confirm against the installed jsonargparse version.
    """

    def parse_string(
        self,
        cfg_str: str,
        cfg_path: str | os.PathLike = "",
        ext_vars: dict | None = None,
        env: bool | None = None,
        defaults: bool = True,
        with_meta: bool | None = None,
        **kwargs: Any,
    ) -> Namespace:
        """Substitute environment variables in ``cfg_str``, then parse it.

        Args:
            cfg_str: the raw configuration content.
            cfg_path: path the content came from, forwarded to the parent parser.
            ext_vars: forwarded unchanged to the parent parser.
            env: forwarded unchanged to the parent parser.
            defaults: forwarded unchanged to the parent parser.
            with_meta: forwarded unchanged to the parent parser.
            **kwargs: any additional options, forwarded unchanged.

        Returns:
            the Namespace produced by the parent ``parse_string``.
        """
        # Substitute environment variables in the config string before parsing
        substituted_cfg_str = substitute_env_vars_in_string(cfg_str)

        # Forward by keyword so a change in the parent's positional parameter
        # order cannot silently route values to the wrong parameters.
        return super().parse_string(
            substituted_cfg_str,
            cfg_path=cfg_path,
            ext_vars=ext_vars,
            env=env,
            defaults=defaults,
            with_meta=with_meta,
            **kwargs,
        )
|
|
@@ -34,7 +34,7 @@ from rslearn.utils.geometry import (
|
|
|
34
34
|
FloatBounds,
|
|
35
35
|
STGeometry,
|
|
36
36
|
flatten_shape,
|
|
37
|
-
|
|
37
|
+
split_shape_at_antimeridian,
|
|
38
38
|
)
|
|
39
39
|
from rslearn.utils.grid_index import GridIndex
|
|
40
40
|
from rslearn.utils.raster_format import get_raster_projection_and_bounds
|
|
@@ -160,7 +160,7 @@ def get_sentinel2_tile_index() -> dict[str, list[FloatBounds]]:
|
|
|
160
160
|
# issues where the tile bounds go from -180 to 180 longitude and thus match
|
|
161
161
|
# with anything at the same latitude.
|
|
162
162
|
union_shp = shapely.unary_union(shapes)
|
|
163
|
-
split_shapes = flatten_shape(
|
|
163
|
+
split_shapes = flatten_shape(split_shape_at_antimeridian(union_shp))
|
|
164
164
|
bounds_list: list[FloatBounds] = []
|
|
165
165
|
for shp in split_shapes:
|
|
166
166
|
bounds_list.append(shp.bounds)
|
|
@@ -222,10 +222,10 @@ def get_sentinel2_tiles(geometry: STGeometry, cache_dir: UPath) -> list[str]:
|
|
|
222
222
|
"""
|
|
223
223
|
tile_index = load_sentinel2_tile_index(cache_dir)
|
|
224
224
|
wgs84_geometry = geometry.to_projection(WGS84_PROJECTION)
|
|
225
|
-
# If the shape is a collection, it could be cutting across
|
|
225
|
+
# If the shape is a collection, it could be cutting across antimeridian.
|
|
226
226
|
# So we query each component shape separately and collect the results to avoid
|
|
227
227
|
# issues.
|
|
228
|
-
# We assume the caller has already applied
|
|
228
|
+
# We assume the caller has already applied split_at_antimeridian.
|
|
229
229
|
results = set()
|
|
230
230
|
for shp in flatten_shape(wgs84_geometry.shp):
|
|
231
231
|
for result in tile_index.query(shp.bounds):
|
|
@@ -82,6 +82,8 @@ class EarthDaily(DataSource, TileStore):
|
|
|
82
82
|
timeout: timedelta = timedelta(seconds=10),
|
|
83
83
|
skip_items_missing_assets: bool = False,
|
|
84
84
|
cache_dir: UPath | None = None,
|
|
85
|
+
max_retries: int = 3,
|
|
86
|
+
retry_backoff_factor: float = 5.0,
|
|
85
87
|
service_name: Literal["platform"] = "platform",
|
|
86
88
|
):
|
|
87
89
|
"""Initialize a new EarthDaily instance.
|
|
@@ -99,6 +101,11 @@ class EarthDaily(DataSource, TileStore):
|
|
|
99
101
|
cache_dir: optional directory to cache items by name, including asset URLs.
|
|
100
102
|
If not set, there will be no cache and instead STAC requests will be
|
|
101
103
|
needed each time.
|
|
104
|
+
max_retries: the maximum number of retry attempts for HTTP requests that fail
|
|
105
|
+
due to transient errors (e.g., 429, 500, 502, 503, 504 status codes).
|
|
106
|
+
retry_backoff_factor: backoff factor for exponential retry delays between HTTP
|
|
107
|
+
request attempts. The delay between retries is calculated using the formula:
|
|
108
|
+
`(retry_backoff_factor * (2 ** (retry_count - 1)))` seconds.
|
|
102
109
|
service_name: the service name, only "platform" is supported, the other
|
|
103
110
|
services "legacy" and "internal" are not supported.
|
|
104
111
|
"""
|
|
@@ -110,6 +117,8 @@ class EarthDaily(DataSource, TileStore):
|
|
|
110
117
|
self.timeout = timeout
|
|
111
118
|
self.skip_items_missing_assets = skip_items_missing_assets
|
|
112
119
|
self.cache_dir = cache_dir
|
|
120
|
+
self.max_retries = max_retries
|
|
121
|
+
self.retry_backoff_factor = retry_backoff_factor
|
|
113
122
|
self.service_name = service_name
|
|
114
123
|
|
|
115
124
|
if cache_dir is not None:
|
|
@@ -139,6 +148,12 @@ class EarthDaily(DataSource, TileStore):
|
|
|
139
148
|
if "cache_dir" in d:
|
|
140
149
|
kwargs["cache_dir"] = join_upath(ds_path, d["cache_dir"])
|
|
141
150
|
|
|
151
|
+
if "max_retries" in d:
|
|
152
|
+
kwargs["max_retries"] = d["max_retries"]
|
|
153
|
+
|
|
154
|
+
if "retry_backoff_factor" in d:
|
|
155
|
+
kwargs["retry_backoff_factor"] = d["retry_backoff_factor"]
|
|
156
|
+
|
|
142
157
|
simple_optionals = ["query", "sort_by", "sort_ascending"]
|
|
143
158
|
for k in simple_optionals:
|
|
144
159
|
if k in d:
|
|
@@ -159,7 +174,12 @@ class EarthDaily(DataSource, TileStore):
|
|
|
159
174
|
if self.eds_client is not None:
|
|
160
175
|
return self.eds_client, self.client, self.collection
|
|
161
176
|
|
|
162
|
-
self.eds_client = EDSClient(
|
|
177
|
+
self.eds_client = EDSClient(
|
|
178
|
+
EDSConfig(
|
|
179
|
+
max_retries=self.max_retries,
|
|
180
|
+
retry_backoff_factor=self.retry_backoff_factor,
|
|
181
|
+
)
|
|
182
|
+
)
|
|
163
183
|
|
|
164
184
|
if self.service_name == "platform":
|
|
165
185
|
self.client = self.eds_client.platform.pystac_client
|
|
@@ -26,7 +26,7 @@ from rslearn.data_sources.utils import match_candidate_items_to_window
|
|
|
26
26
|
from rslearn.log_utils import get_logger
|
|
27
27
|
from rslearn.tile_stores import TileStoreWithLayer
|
|
28
28
|
from rslearn.utils.fsspec import join_upath, open_atomic
|
|
29
|
-
from rslearn.utils.geometry import STGeometry, flatten_shape,
|
|
29
|
+
from rslearn.utils.geometry import STGeometry, flatten_shape, split_at_antimeridian
|
|
30
30
|
from rslearn.utils.raster_format import get_raster_projection_and_bounds
|
|
31
31
|
|
|
32
32
|
from .copernicus import get_harmonize_callback, get_sentinel2_tiles
|
|
@@ -358,7 +358,7 @@ class Sentinel2(DataSource):
|
|
|
358
358
|
shp = shapely.box(*bounds)
|
|
359
359
|
sensing_time = row["sensing_time"]
|
|
360
360
|
geometry = STGeometry(WGS84_PROJECTION, shp, (sensing_time, sensing_time))
|
|
361
|
-
geometry =
|
|
361
|
+
geometry = split_at_antimeridian(geometry)
|
|
362
362
|
|
|
363
363
|
cloud_cover = float(row["cloud_cover"])
|
|
364
364
|
|
|
@@ -511,7 +511,7 @@ class Sentinel2(DataSource):
|
|
|
511
511
|
|
|
512
512
|
time_range = (product_xml.start_time, product_xml.start_time)
|
|
513
513
|
geometry = STGeometry(WGS84_PROJECTION, product_xml.shp, time_range)
|
|
514
|
-
geometry =
|
|
514
|
+
geometry = split_at_antimeridian(geometry)
|
|
515
515
|
|
|
516
516
|
# Sometimes the geometry is not valid.
|
|
517
517
|
# We just apply make_valid on it to correct issues.
|
rslearn/data_sources/utils.py
CHANGED
|
@@ -256,23 +256,7 @@ def match_candidate_items_to_window(
|
|
|
256
256
|
if item_geom.is_global():
|
|
257
257
|
item_geom = geometry
|
|
258
258
|
else:
|
|
259
|
-
|
|
260
|
-
# So we first clip the item to the window bounds in the item's
|
|
261
|
-
# projection, then re-project the item to the window's projection.
|
|
262
|
-
buffered_window_geom = STGeometry(
|
|
263
|
-
geometry.projection,
|
|
264
|
-
geometry.shp.buffer(1),
|
|
265
|
-
geometry.time_range,
|
|
266
|
-
)
|
|
267
|
-
window_shp_in_item_proj = buffered_window_geom.to_projection(
|
|
268
|
-
item_geom.projection
|
|
269
|
-
).shp
|
|
270
|
-
clipped_item_geom = STGeometry(
|
|
271
|
-
item_geom.projection,
|
|
272
|
-
item_geom.shp.intersection(window_shp_in_item_proj),
|
|
273
|
-
item_geom.time_range,
|
|
274
|
-
)
|
|
275
|
-
item_geom = clipped_item_geom.to_projection(geometry.projection)
|
|
259
|
+
item_geom = item_geom.to_projection(geometry.projection)
|
|
276
260
|
item_shps.append(item_geom.shp)
|
|
277
261
|
|
|
278
262
|
if query_config.space_mode == SpaceMode.CONTAINS:
|
rslearn/main.py
CHANGED
|
@@ -13,6 +13,7 @@ from lightning.pytorch.cli import LightningArgumentParser, LightningCLI
|
|
|
13
13
|
from rasterio.crs import CRS
|
|
14
14
|
from upath import UPath
|
|
15
15
|
|
|
16
|
+
from rslearn.arg_parser import RslearnArgumentParser
|
|
16
17
|
from rslearn.config import LayerConfig
|
|
17
18
|
from rslearn.const import WGS84_EPSG
|
|
18
19
|
from rslearn.data_sources import Item, data_source_from_config
|
|
@@ -779,7 +780,7 @@ def dataset_build_index() -> None:
|
|
|
779
780
|
|
|
780
781
|
|
|
781
782
|
class RslearnLightningCLI(LightningCLI):
|
|
782
|
-
"""LightningCLI that links data.tasks to model.tasks."""
|
|
783
|
+
"""LightningCLI that links data.tasks to model.tasks and supports environment variables."""
|
|
783
784
|
|
|
784
785
|
def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
|
|
785
786
|
"""Link data.tasks to model.tasks.
|
|
@@ -787,6 +788,7 @@ class RslearnLightningCLI(LightningCLI):
|
|
|
787
788
|
Args:
|
|
788
789
|
parser: the argument parser
|
|
789
790
|
"""
|
|
791
|
+
# Link data.tasks to model.tasks
|
|
790
792
|
parser.link_arguments(
|
|
791
793
|
"data.init_args.task", "model.init_args.task", apply_on="instantiate"
|
|
792
794
|
)
|
|
@@ -815,6 +817,12 @@ class RslearnLightningCLI(LightningCLI):
|
|
|
815
817
|
# sampler as needed.
|
|
816
818
|
c.trainer.use_distributed_sampler = False
|
|
817
819
|
|
|
820
|
+
# For predict, make sure that return_predictions is False.
|
|
821
|
+
# Otherwise all the predictions would be stored in memory which can lead to
|
|
822
|
+
# high memory consumption.
|
|
823
|
+
if subcommand == "predict":
|
|
824
|
+
c.return_predictions = False
|
|
825
|
+
|
|
818
826
|
|
|
819
827
|
def model_handler() -> None:
|
|
820
828
|
"""Handler for any rslearn model X commands."""
|
|
@@ -825,6 +833,7 @@ def model_handler() -> None:
|
|
|
825
833
|
subclass_mode_model=True,
|
|
826
834
|
subclass_mode_data=True,
|
|
827
835
|
save_config_kwargs={"overwrite": True},
|
|
836
|
+
parser_class=RslearnArgumentParser,
|
|
828
837
|
)
|
|
829
838
|
|
|
830
839
|
|
rslearn/models/trunk.py
CHANGED
|
@@ -6,7 +6,6 @@ from typing import Any
|
|
|
6
6
|
import torch
|
|
7
7
|
|
|
8
8
|
from rslearn.log_utils import get_logger
|
|
9
|
-
from rslearn.models.moe.soft import SoftMoE
|
|
10
9
|
from rslearn.models.task_embedding import BaseTaskEmbedding
|
|
11
10
|
|
|
12
11
|
logger = get_logger(__name__)
|
|
@@ -135,146 +134,3 @@ class DecoderTrunk(torch.nn.Module):
|
|
|
135
134
|
"""
|
|
136
135
|
for layer in self.layers:
|
|
137
136
|
layer.apply_auxiliary_losses(trunk_out, outs)
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
class MoETransformer(DecoderTrunkLayer):
    """Transformer for decoder trunk.

    Stack of self-attention blocks whose feed-forward part is either a SoftMoE
    layer or, with ``disable_moe=True``, a dense LayerNorm/Linear/GELU/Linear MLP.
    NOTE: this class is removed from rslearn/models/trunk.py in 0.0.4.
    """

    def __init__(
        self,
        dim: int,
        n_layers: int,
        n_heads: int,
        mlp_dim: int = 512,
        dropout: float = 0.1,
        task_moe: bool = False,
        disable_moe: bool = False,
        num_experts: int = 16,
        num_slots: int = 256,
        expert_mult: int = 4,
        load_balance_loss_weight: float = 0.0,
    ) -> None:
        """Standard ViT-style transformer, with soft MoE.

        Since the point of the MoE layers is to deal with task-specific and task-shared
        features (and not to route specific tokens), it's probably best to use max_seq_len
        as the number of slots, and have at least one expert per task (probably more).

        Args:
            dim: dimension of the input and output
            n_layers: number of transformer blocks
            n_heads: number of attention heads
            mlp_dim: hidden dimension of the dense MLP; only used when
                disable_moe is True (it is not passed to SoftMoE)
            dropout: dropout rate
            task_moe: if True, compute dispatch weights given the task embedding
                only, and not the token
            disable_moe: if True, disable MoE
            num_experts: number of experts in soft MoE
            num_slots: number of slots in soft MoE
            expert_mult: hidden-layer multiplier forwarded to SoftMoE
                (NOTE(review): how SoftMoE applies it is not visible here)
            load_balance_loss_weight: weight of the load balance loss; the loss is
                only added to the loss dict when this is > 0
        """
        super().__init__()
        self.disable_moe = disable_moe
        self.num_experts = num_experts
        self.num_slots = num_slots
        self.task_moe = task_moe
        self.load_balance_loss_weight = load_balance_loss_weight
        self.norm = torch.nn.LayerNorm(dim)
        self.layers = torch.nn.ModuleList([])
        for _ in range(n_layers):
            mha = torch.nn.MultiheadAttention(
                dim, n_heads, dropout=dropout, batch_first=True
            )
            if not disable_moe:
                ffn = SoftMoE(
                    dim=dim,
                    num_experts=num_experts,
                    num_slots=num_slots,
                    dropout=dropout,
                    expert_mult=expert_mult,
                )
            else:
                # Dense fallback: pre-LayerNorm MLP with hidden size mlp_dim.
                ffn = torch.nn.Sequential(
                    torch.nn.LayerNorm(dim),
                    torch.nn.Linear(dim, mlp_dim),
                    torch.nn.GELU(),
                    torch.nn.Linear(mlp_dim, dim),
                )
            drop = torch.nn.Dropout(dropout)
            # Each block is stored as [attention, feed-forward, dropout].
            self.layers.append(torch.nn.ModuleList([mha, ffn, drop]))

    def forward(
        self, x: torch.Tensor, task_embedding: torch.Tensor | None = None
    ) -> dict[str, Any]:
        """Forward pass.

        Args:
            x: input tensor of shape (batch_size, seq_len, dim)
            task_embedding: task embedding tensor of shape (batch_size, dim)

        Returns:
            dict with key "outputs" (output tensor of shape (batch_size, seq_len, dim))
            and, when MoE is enabled, "load_balance_loss", "dispatch_weights", and
            "combine_weights", each a list with one entry per layer that reported it.
        """
        # Forward pass through the transformer
        infos: list[dict[str, Any]] = []
        for mha, ffn, drop in self.layers:
            # Self-attention with a residual connection; no LayerNorm precedes it.
            x = mha(x, x, x)[0] + x
            if not self.disable_moe:
                # SoftMoE returns a dict; "outputs" is the FFN result and the
                # remaining entries are per-layer diagnostics collected below.
                outs = ffn(x, weight_key=task_embedding if self.task_moe else None)
                x_ffn = outs.pop("outputs")
                infos.append(outs)
                # NOTE(review): dropout is applied to (ffn output + residual), so
                # it also drops the residual stream -- confirm this is intended.
                x = drop(x_ffn + x)
            else:
                x = drop(ffn(x) + x)
        x = self.norm(x)
        outputs = {"outputs": x}

        # If using MoE, collect expert weights and auxiliary losses
        # Don't call detach because we will use this later on in the loss collation
        if not self.disable_moe:
            collated: dict[str, list[torch.Tensor]] = {
                "load_balance_loss": [],
                "dispatch_weights": [],
                "combine_weights": [],
            }
            for info in infos:
                for k, v in info.items():
                    if k == "dispatch_weights":
                        # each weight is [batch, seq_len, num_experts, num_slots]
                        # compute avg weight per token across slot/batch/expert
                        # NOTE: this is probably about the same across all tokens,
                        # assuming all tokens get looked at by a few experts
                        collated["dispatch_weights"].append(v.mean((0, 2, 3)))

                    elif k == "combine_weights":
                        # each weight is [batch, seq_len, num_experts * num_slots]
                        # compute avg weight per expert (slot group) across batch/seq
                        v = v.unflatten(-1, (self.num_experts, self.num_slots))
                        v = v.sum(-1)  # [batch, seq_len, num_experts (softmax)]
                        collated["combine_weights"].append(v.mean((0, 1)))

                    elif k == "load_balance_loss":
                        # each load balance loss per layer is a scalar
                        collated["load_balance_loss"].append(v)
            outputs.update(collated)

        return outputs

    def apply_auxiliary_losses(
        self, trunk_out: dict[str, Any], outs: dict[str, Any]
    ) -> None:
        """Apply auxiliary losses in-place.

        Just move the load balance loss to the loss dict, where it will eventually be summed.
        The per-layer losses are averaged and scaled by load_balance_loss_weight.

        Args:
            trunk_out: The output of the trunk.
            outs: The output of the decoders, with key "loss_dict" containing the losses.
        """
        # NOTE(review): when MoE is enabled, "load_balance_loss" is always present
        # in trunk_out, but the list is empty if SoftMoE never reports that key,
        # and torch.stack on an empty list raises.
        if "load_balance_loss" in trunk_out and self.load_balance_loss_weight > 0.0:
            total_aux_loss = torch.stack(trunk_out["load_balance_loss"]).mean()
            outs["loss_dict"]["load_balance_loss"] = (
                self.load_balance_loss_weight * total_aux_loss
            )
|
|
@@ -0,0 +1,53 @@
|
|
|
1
|
+
"""Callback to activate/deactivate adapter layers."""
|
|
2
|
+
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
from lightning.pytorch import LightningModule
|
|
6
|
+
from lightning.pytorch.callbacks import Callback
|
|
7
|
+
from lightning.pytorch.trainer import Trainer
|
|
8
|
+
|
|
9
|
+
from rslearn.log_utils import get_logger
|
|
10
|
+
|
|
11
|
+
# Module-level logger used by the callback below to report adapter status.
logger = get_logger(__name__)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class ActivateLayers(Callback):
    """Activates adapter layers on a given epoch.

    At the start of every training epoch, each module whose qualified name (as
    reported by ``named_modules()``) contains a selector's ``name`` substring has
    its ``active`` attribute set: ``False`` while the current epoch is below the
    selector's ``at_epoch`` and ``True`` from that epoch onward. Modules that no
    selector matches are not touched. Because the comparison is against
    ``trainer.current_epoch``, an adapter layer stays active once activated as
    long as the epoch counter does not go backwards.

    The substring match also hits child modules of a matched module (their
    qualified names contain the parent's name), and an empty ``name`` matches
    every module, including the root. If a module matches several selectors,
    the last matching selector in the list determines its state.
    """

    def __init__(self, selectors: list[dict[str, Any]]) -> None:
        """Initialize the callback.

        Args:
            selectors: List of selectors to activate.
                Each selector is a dictionary with the following keys:
                - "name": Substring selector of modules to activate (str).
                - "at_epoch": The epoch at which to activate (int).
        """
        self.selectors = selectors

    def on_train_epoch_start(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
    ) -> None:
        """Activate adapter layers on a given epoch.

        Adapter layers are activated/deactivated by setting the `active` attribute.
        A warning is logged for any selector that matches no module, since that
        usually indicates a typo in the selector name.

        Args:
            trainer: The trainer object.
            pl_module: The LightningModule object.
        """
        status: dict[str, str] = {}
        for name, module in pl_module.named_modules():
            for selector in self.selectors:
                if selector["name"] in name:
                    module.active = trainer.current_epoch >= selector["at_epoch"]
                    status[selector["name"]] = "active" if module.active else "inactive"

        # A selector that matched nothing leaves every adapter untouched; surface
        # it instead of failing silently.
        for selector in self.selectors:
            if selector["name"] not in status:
                logger.warning(
                    "Adapter selector %r did not match any module", selector["name"]
                )

        logger.info("Updated adapter status: %s", status)
|