rslearn 0.0.1__py3-none-any.whl → 0.0.21__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- rslearn/arg_parser.py +31 -0
- rslearn/config/__init__.py +6 -12
- rslearn/config/dataset.py +520 -401
- rslearn/const.py +9 -15
- rslearn/data_sources/__init__.py +8 -23
- rslearn/data_sources/aws_landsat.py +242 -98
- rslearn/data_sources/aws_open_data.py +111 -151
- rslearn/data_sources/aws_sentinel1.py +131 -0
- rslearn/data_sources/climate_data_store.py +471 -0
- rslearn/data_sources/copernicus.py +884 -12
- rslearn/data_sources/data_source.py +43 -12
- rslearn/data_sources/earthdaily.py +484 -0
- rslearn/data_sources/earthdata_srtm.py +282 -0
- rslearn/data_sources/eurocrops.py +242 -0
- rslearn/data_sources/gcp_public_data.py +578 -222
- rslearn/data_sources/google_earth_engine.py +461 -135
- rslearn/data_sources/local_files.py +219 -150
- rslearn/data_sources/openstreetmap.py +51 -89
- rslearn/data_sources/planet.py +24 -60
- rslearn/data_sources/planet_basemap.py +275 -0
- rslearn/data_sources/planetary_computer.py +798 -0
- rslearn/data_sources/usda_cdl.py +195 -0
- rslearn/data_sources/usgs_landsat.py +115 -83
- rslearn/data_sources/utils.py +249 -61
- rslearn/data_sources/vector_source.py +1 -0
- rslearn/data_sources/worldcereal.py +449 -0
- rslearn/data_sources/worldcover.py +144 -0
- rslearn/data_sources/worldpop.py +153 -0
- rslearn/data_sources/xyz_tiles.py +150 -107
- rslearn/dataset/__init__.py +8 -2
- rslearn/dataset/add_windows.py +2 -2
- rslearn/dataset/dataset.py +40 -51
- rslearn/dataset/handler_summaries.py +131 -0
- rslearn/dataset/manage.py +313 -74
- rslearn/dataset/materialize.py +431 -107
- rslearn/dataset/remap.py +29 -4
- rslearn/dataset/storage/__init__.py +1 -0
- rslearn/dataset/storage/file.py +202 -0
- rslearn/dataset/storage/storage.py +140 -0
- rslearn/dataset/window.py +181 -44
- rslearn/lightning_cli.py +454 -0
- rslearn/log_utils.py +24 -0
- rslearn/main.py +384 -181
- rslearn/models/anysat.py +215 -0
- rslearn/models/attention_pooling.py +177 -0
- rslearn/models/clay/clay.py +231 -0
- rslearn/models/clay/configs/metadata.yaml +295 -0
- rslearn/models/clip.py +68 -0
- rslearn/models/component.py +111 -0
- rslearn/models/concatenate_features.py +103 -0
- rslearn/models/conv.py +63 -0
- rslearn/models/croma.py +306 -0
- rslearn/models/detr/__init__.py +5 -0
- rslearn/models/detr/box_ops.py +103 -0
- rslearn/models/detr/detr.py +504 -0
- rslearn/models/detr/matcher.py +107 -0
- rslearn/models/detr/position_encoding.py +114 -0
- rslearn/models/detr/transformer.py +429 -0
- rslearn/models/detr/util.py +24 -0
- rslearn/models/dinov3.py +177 -0
- rslearn/models/faster_rcnn.py +30 -28
- rslearn/models/feature_center_crop.py +53 -0
- rslearn/models/fpn.py +19 -8
- rslearn/models/galileo/__init__.py +5 -0
- rslearn/models/galileo/galileo.py +595 -0
- rslearn/models/galileo/single_file_galileo.py +1678 -0
- rslearn/models/module_wrapper.py +65 -0
- rslearn/models/molmo.py +69 -0
- rslearn/models/multitask.py +384 -28
- rslearn/models/olmoearth_pretrain/__init__.py +1 -0
- rslearn/models/olmoearth_pretrain/model.py +421 -0
- rslearn/models/olmoearth_pretrain/norm.py +86 -0
- rslearn/models/panopticon.py +170 -0
- rslearn/models/panopticon_data/sensors/drone.yaml +32 -0
- rslearn/models/panopticon_data/sensors/enmap.yaml +904 -0
- rslearn/models/panopticon_data/sensors/goes.yaml +9 -0
- rslearn/models/panopticon_data/sensors/himawari.yaml +9 -0
- rslearn/models/panopticon_data/sensors/intuition.yaml +606 -0
- rslearn/models/panopticon_data/sensors/landsat8.yaml +84 -0
- rslearn/models/panopticon_data/sensors/modis_terra.yaml +99 -0
- rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +34 -0
- rslearn/models/panopticon_data/sensors/sentinel1.yaml +85 -0
- rslearn/models/panopticon_data/sensors/sentinel2.yaml +97 -0
- rslearn/models/panopticon_data/sensors/superdove.yaml +60 -0
- rslearn/models/panopticon_data/sensors/wv23.yaml +63 -0
- rslearn/models/pick_features.py +17 -10
- rslearn/models/pooling_decoder.py +60 -7
- rslearn/models/presto/__init__.py +5 -0
- rslearn/models/presto/presto.py +297 -0
- rslearn/models/presto/single_file_presto.py +926 -0
- rslearn/models/prithvi.py +1147 -0
- rslearn/models/resize_features.py +59 -0
- rslearn/models/sam2_enc.py +13 -9
- rslearn/models/satlaspretrain.py +38 -18
- rslearn/models/simple_time_series.py +188 -77
- rslearn/models/singletask.py +24 -13
- rslearn/models/ssl4eo_s12.py +40 -30
- rslearn/models/swin.py +44 -32
- rslearn/models/task_embedding.py +250 -0
- rslearn/models/terramind.py +256 -0
- rslearn/models/trunk.py +139 -0
- rslearn/models/unet.py +68 -22
- rslearn/models/upsample.py +48 -0
- rslearn/models/use_croma.py +508 -0
- rslearn/template_params.py +26 -0
- rslearn/tile_stores/__init__.py +41 -18
- rslearn/tile_stores/default.py +409 -0
- rslearn/tile_stores/tile_store.py +236 -132
- rslearn/train/all_patches_dataset.py +530 -0
- rslearn/train/callbacks/adapters.py +53 -0
- rslearn/train/callbacks/freeze_unfreeze.py +348 -17
- rslearn/train/callbacks/gradients.py +129 -0
- rslearn/train/callbacks/peft.py +116 -0
- rslearn/train/data_module.py +444 -20
- rslearn/train/dataset.py +588 -235
- rslearn/train/lightning_module.py +192 -62
- rslearn/train/model_context.py +88 -0
- rslearn/train/optimizer.py +31 -0
- rslearn/train/prediction_writer.py +319 -84
- rslearn/train/scheduler.py +92 -0
- rslearn/train/tasks/classification.py +55 -28
- rslearn/train/tasks/detection.py +132 -76
- rslearn/train/tasks/embedding.py +120 -0
- rslearn/train/tasks/multi_task.py +28 -14
- rslearn/train/tasks/per_pixel_regression.py +291 -0
- rslearn/train/tasks/regression.py +161 -44
- rslearn/train/tasks/segmentation.py +428 -53
- rslearn/train/tasks/task.py +6 -5
- rslearn/train/transforms/__init__.py +1 -1
- rslearn/train/transforms/concatenate.py +54 -10
- rslearn/train/transforms/crop.py +29 -11
- rslearn/train/transforms/flip.py +18 -6
- rslearn/train/transforms/mask.py +78 -0
- rslearn/train/transforms/normalize.py +101 -17
- rslearn/train/transforms/pad.py +19 -7
- rslearn/train/transforms/resize.py +83 -0
- rslearn/train/transforms/select_bands.py +76 -0
- rslearn/train/transforms/sentinel1.py +75 -0
- rslearn/train/transforms/transform.py +89 -70
- rslearn/utils/__init__.py +2 -6
- rslearn/utils/array.py +8 -6
- rslearn/utils/feature.py +2 -2
- rslearn/utils/fsspec.py +90 -1
- rslearn/utils/geometry.py +347 -7
- rslearn/utils/get_utm_ups_crs.py +2 -3
- rslearn/utils/grid_index.py +5 -5
- rslearn/utils/jsonargparse.py +178 -0
- rslearn/utils/mp.py +4 -3
- rslearn/utils/raster_format.py +268 -116
- rslearn/utils/rtree_index.py +64 -17
- rslearn/utils/sqlite_index.py +7 -1
- rslearn/utils/vector_format.py +252 -97
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/METADATA +532 -283
- rslearn-0.0.21.dist-info/RECORD +167 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/WHEEL +1 -1
- rslearn-0.0.21.dist-info/licenses/NOTICE +115 -0
- rslearn/data_sources/raster_source.py +0 -309
- rslearn/models/registry.py +0 -5
- rslearn/tile_stores/file.py +0 -242
- rslearn/utils/mgrs.py +0 -24
- rslearn/utils/utils.py +0 -22
- rslearn-0.0.1.dist-info/RECORD +0 -88
- /rslearn/{data_sources/geotiff.py → py.typed} +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info/licenses}/LICENSE +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/top_level.txt +0 -0
rslearn/train/callbacks/freeze_unfreeze.py

@@ -1,10 +1,20 @@
 """FreezeUnfreeze callback."""
 
+from __future__ import annotations
+
+from collections.abc import Iterable, Sequence
+from dataclasses import dataclass
+
 import torch
 from lightning.pytorch import LightningModule
 from lightning.pytorch.callbacks import BaseFinetuning
+from torch.optim.lr_scheduler import ReduceLROnPlateau
 from torch.optim.optimizer import Optimizer
 
+from rslearn.log_utils import get_logger
+
+logger = get_logger(__name__)
+
 
 class FreezeUnfreeze(BaseFinetuning):
     """Freezes a module and optionally unfreezes it after a number of epochs."""
@@ -14,7 +24,7 @@ class FreezeUnfreeze(BaseFinetuning):
         module_selector: list[str | int],
         unfreeze_at_epoch: int | None = None,
         unfreeze_lr_factor: float = 1,
-    ):
+    ) -> None:
         """Creates a new FreezeUnfreeze.
 
         Args:
@@ -30,6 +40,8 @@ class FreezeUnfreeze(BaseFinetuning):
         self.module_selector = module_selector
         self.unfreeze_at_epoch = unfreeze_at_epoch
         self.unfreeze_lr_factor = unfreeze_lr_factor
+        if unfreeze_at_epoch == 0:
+            raise ValueError("unfreeze_at_epoch cannot be 0")
 
     def _get_target_module(self, pl_module: LightningModule) -> torch.nn.Module:
         target_module = pl_module
@@ -40,18 +52,18 @@ class FreezeUnfreeze(BaseFinetuning):
             target_module = getattr(target_module, k)
         return target_module
 
-    def freeze_before_training(self, pl_module: LightningModule):
+    def freeze_before_training(self, pl_module: LightningModule) -> None:
         """Freeze the model at the beginning of training.
 
         Args:
             pl_module: the LightningModule.
         """
-
+        logger.info(f"freezing model at {self.module_selector}")
         self.freeze(self._get_target_module(pl_module))
 
     def finetune_function(
         self, pl_module: LightningModule, current_epoch: int, optimizer: Optimizer
-    ):
+    ) -> None:
         """Check whether we should unfreeze the model on each epoch.
 
         Args:
@@ -61,19 +73,338 @@ class FreezeUnfreeze(BaseFinetuning):
         """
         if self.unfreeze_at_epoch is None:
            return
-
-
-
-
+        elif current_epoch == self.unfreeze_at_epoch:
+            logger.info(
+                f"unfreezing model at {self.module_selector} since we are on epoch {current_epoch}"
+            )
+            self.unfreeze_and_add_param_group(
+                modules=self._get_target_module(pl_module),
+                optimizer=optimizer,
+                initial_denom_lr=self.unfreeze_lr_factor,
+            )
+            if "scheduler" in pl_module.schedulers:
+                scheduler = pl_module.schedulers["scheduler"]
+                if isinstance(scheduler, ReduceLROnPlateau):
+                    while len(scheduler.min_lrs) < len(optimizer.param_groups):
+                        logger.info(
+                            "appending to ReduceLROnPlateau scheduler min_lrs for unfreeze"
+                        )
+                        scheduler.min_lrs.append(scheduler.min_lrs[0])
+        elif current_epoch > self.unfreeze_at_epoch:
+            # always do this because overhead is minimal, and it allows restoring
+            # from a checkpoint (resuming a run) without messing up unfreezing
+            BaseFinetuning.make_trainable(self._get_target_module(pl_module))
+
+
+@dataclass
+class FTStage:
+    """Specification for a single fine-tuning stage.
+
+    Each stage is activated when the trainer reaches a specific epoch (`at_epoch`).
+    Within that stage, modules whose **qualified name** (from `named_modules()`)
+    matches any substring in `freeze_selectors` will be frozen, except those whose
+    name matches any substring in `unfreeze_selectors`, which are forced trainable.
+
+    freeze_selectors does not carry over to other stages. That is, if you freeze module
+    A for stage 1, it will not be frozen for stage 2 unless specified again in stage 2.
+    All stages indepedently update trainability of all modules specified or unspecified.
+
+    Args:
+        at_epoch: Epoch index at which to apply this stage (0-based).
+        freeze_selectors: Substrings; any module name containing any of these will
+            be frozen in this stage (unless also matched by `unfreeze_selectors`).
+        unfreeze_selectors: Substrings; any module name containing any of these
+            will be **unfrozen** (trainable) in this stage, overriding freezes.
+        unfreeze_lr_factor: When parameters become trainable and are **not yet**
+            part of the optimizer, a new param group is added with learning rate
+            `base_lr / unfreeze_lr_factor`. Use 1.0 to keep the base learning rate.
+        scale_existing_groups: If provided and not 1.0, multiply the learning rate
+            of **all existing optimizer param groups** by this factor at the moment
+            this stage is applied. Use this to calm down previously-trainable
+            parts (e.g., the head) when unfreezing deeper layers.
+            Set to ``None`` to leave existing groups unchanged.
+    """
+
+    at_epoch: int
+    freeze_selectors: Sequence[str]
+    unfreeze_selectors: Sequence[str]
+    unfreeze_lr_factor: float = 1.0
+    scale_existing_groups: float | None = None
+
+
+class MultiStageFineTuning(BaseFinetuning):
+    """Multi-stage fine-tuning with flexible name-based selection.
+
+    Behavior per stage:
+      1) Start from a **fully trainable** baseline.
+      2) Optionally **scale existing optimizer groups** via `scale_existing_groups`.
+      3) **Freeze** modules matching any `freeze_selectors`.
+      4) **Unfreeze** modules matching any `unfreeze_selectors` (overrides step 3).
+      5) For newly trainable parameters **not yet** in the optimizer, add a new
+         param group using `unfreeze_lr_factor` (lr = base_lr / factor).
+
+    Stages are applied exactly once at their `at_epoch`. The plan is recomputed
+    from scratch at each stage to keep behavior predictable on resume.
+    """
+
+    def __init__(self, stages: list[FTStage]) -> None:
+        """Multi-stage fine-tuning with flexible name-based selection.
+
+        Args:
+            stages: A sequence of stage specifications.
+
+        Raises:
+            ValueError: If two stages specify the same `at_epoch`.
+        """
+        super().__init__()
+        self.stages = stages
+
+        # Validate uniqueness of epochs and sort stages.
+        seen: set[int] = set()
+        for st in self.stages:
+            if st.at_epoch in seen:
+                raise ValueError(f"Duplicate at_epoch in stages: {st.at_epoch}")
+            if st.scale_existing_groups is not None and st.scale_existing_groups <= 0.0:
+                raise ValueError("scale_existing_groups, if set, must be > 0.")
+            seen.add(st.at_epoch)
+        self.stages.sort(key=lambda x: x.at_epoch)
+
+        self._applied_epochs: set[int] = set()
+
+    @staticmethod
+    def _freeze_unfreeze(mod: torch.nn.Module, freeze: bool) -> None:
+        """Freeze or unfreeze all parameters of a module without going through Lightning's flatten logic.
+
+        This is a workaround to avoid infinite recursion on ModuleDicts.
+
+        Args:
+            mod: The module to freeze.
+            freeze: Whether to freeze the module.
+        """
+        for p in mod.parameters(recurse=True):
+            p.requires_grad = not freeze
+
+    @staticmethod
+    def _names_matching(names: Iterable[str], selectors: Sequence[str]) -> set[str]:
+        """Return the subset of `names` that contains any of the given selectors.
+
+        Matching is done via simple substring checks (`sel in name`).
+
+        Args:
+            names: Iterable of qualified module names (e.g., from `named_modules()`).
+            selectors: Substrings to match against each name. Empty strings are ignored.
+
+        Returns:
+            A set of names from `names` that match at least one selector.
+        """
+        if not selectors:
+            return set()
+        sels: list[str] = [s for s in selectors if s]
+        out: set[str] = set()
+        for n in names:
+            if any(sel in n for sel in sels):
+                out.add(n)
+        return out
+
+    @staticmethod
+    def _modules_by_names(
+        root: torch.nn.Module, wanted: set[str]
+    ) -> list[torch.nn.Module]:
+        """Map qualified names to module objects.
+
+        Args:
+            root: The root module (e.g., your LightningModule).
+            wanted: Qualified names of submodules to retrieve.
+
+        Returns:
+            A list of modules corresponding to the given names that exist under `root`.
+        """
+        if not wanted:
+            return []
+        name_to_module: dict[str, torch.nn.Module] = dict(root.named_modules())
+        return [name_to_module[n] for n in wanted if n in name_to_module]
+
+    @staticmethod
+    def _existing_param_ids(optimizer: Optimizer) -> set[int]:
+        """Collect ids of all parameters already tracked by the optimizer.
+
+        Args:
+            optimizer: The optimizer to inspect.
+
+        Returns:
+            A set of parameter ids already tracked by the optimizer.
+        """
+        return {id(p) for g in optimizer.param_groups for p in g["params"]}
+
+    @staticmethod
+    def _iter_module_params(modules: list[torch.nn.Module]) -> list[torch.nn.Parameter]:
+        """Flatten parameters from a list of modules (no duplicates, trainable first).
+
+        Args:
+            modules: A list of modules to inspect.
+
+        Returns:
+            A list of parameters from the modules, in order of appearance.
+        """
+        seen: set[int] = set()
+        ordered: list[torch.nn.Parameter] = []
+        for m in modules:
+            for p in m.parameters():
+                if id(p) not in seen:
+                    seen.add(id(p))
+                    ordered.append(p)
+        return ordered
+
+    def _apply_stage(
+        self, pl_module: LightningModule, optimizer: Optimizer, stage: FTStage
+    ) -> None:
+        """Apply a single fine-tuning stage to `pl_module` and `optimizer`.
+
+        Order of operations:
+          1) Make everything trainable (baseline).
+          2) If `scale_existing_groups` is set, multiply LR of **existing** optimizer
+             groups by this factor (and update ReduceLROnPlateau `min_lrs` if present).
+          3) Freeze modules matched by `freeze_selectors` minus `unfreeze_selectors`.
+          4) Ensure modules matched by `unfreeze_selectors` are trainable.
+          5) Add new optimizer param groups for newly-trainable modules with LR
+             scaled by `unfreeze_lr_factor`.
+
+        Args:
+            pl_module: The LightningModule being trained.
+            optimizer: The optimizer currently used by the trainer.
+            stage: The stage specification to apply at the current epoch.
+
+        Returns:
+            None.
+        """
+        model: torch.nn.Module = pl_module
+        all_names: list[str] = [n for n, _ in model.named_modules()]
+
+        freeze_names: set[str] = self._names_matching(all_names, stage.freeze_selectors)
+        unfreeze_names: set[str] = self._names_matching(
+            all_names, stage.unfreeze_selectors
+        )
+
+        # 1) Baseline: everything trainable.
+        self._freeze_unfreeze(model, freeze=False)
+
+        # 2) Optionally scale existing optimizer groups (e.g., calm down the head).
+        if (
+            stage.scale_existing_groups is not None
+            and stage.scale_existing_groups != 1.0
+        ):
+            factor: float = stage.scale_existing_groups
+            for g in optimizer.param_groups:
+                old_lr = float(g.get("lr", 0.0))
+                g["lr"] = old_lr * factor
+            # Keep ReduceLROnPlateau bounds consistent if present.
+            if hasattr(pl_module, "schedulers") and "scheduler" in getattr(
+                pl_module, "schedulers", {}
+            ):
+                scheduler = pl_module.schedulers["scheduler"]
+                if isinstance(scheduler, ReduceLROnPlateau):
+                    scheduler.min_lrs = [float(m) * factor for m in scheduler.min_lrs]
+
+        # 3) Freeze matched, except those explicitly unfreezed.
+        to_freeze: set[str] = freeze_names - unfreeze_names
+        freeze_modules: list[torch.nn.Module] = self._modules_by_names(model, to_freeze)
+        if freeze_modules:
+            to_display = sorted(list(to_freeze))
+            logger.info(
+                f"[FT stage @ epoch {stage.at_epoch}] Freezing {len(freeze_modules)} modules "
+                f"(matched: {to_display[:2] + to_display[-2:]}{'...' if len(to_freeze) > 4 else ''})"
+            )
+            for m in freeze_modules:
+                self._freeze_unfreeze(m, freeze=True)
+
+        # 4) Ensure explicitly unfreezed modules are trainable.
+        unfreeze_modules: list[torch.nn.Module] = self._modules_by_names(
+            model, unfreeze_names
         )
-
-
-
-
+        if unfreeze_modules:
+            to_display = sorted(list(unfreeze_names))
+            logger.info(
+                f"[FT stage @ epoch {stage.at_epoch}] Unfreezing {len(unfreeze_modules)} modules "
+                f"(matched: {to_display[:2] + to_display[-2:]}{'...' if len(unfreeze_names) > 4 else ''})"
+            )
+            for m in unfreeze_modules:
+                self._freeze_unfreeze(m, freeze=False)
+
+        # 5) Add *newly-trainable* params only (no duplicates)
+        denom: float = (
+            stage.unfreeze_lr_factor if stage.unfreeze_lr_factor != 1.0 else 1.0
+        )
+        all_params = self._iter_module_params(unfreeze_modules)
+        already = self._existing_param_ids(optimizer)
+        new_params = [
+            p for p in all_params if p.requires_grad and id(p) not in already
+        ]
+
+        if new_params:
+            # Use current "base" lr (after any scale_existing_groups) as the reference
+            base_lr = float(optimizer.param_groups[0].get("lr", 0.0))
+            group_lr = base_lr / denom if denom != 0 else base_lr
+            optimizer.add_param_group({"params": new_params, "lr": group_lr})
+
+            # Extend ReduceLROnPlateau.min_lrs to match param group count
+            if hasattr(pl_module, "schedulers") and "scheduler" in getattr(
+                pl_module, "schedulers", {}
+            ):
+                scheduler = pl_module.schedulers["scheduler"]
+                if isinstance(scheduler, ReduceLROnPlateau):
+                    while len(scheduler.min_lrs) < len(optimizer.param_groups):
+                        logger.info(
+                            "Extending ReduceLROnPlateau.min_lrs for new param group"
+                        )
+                        scheduler.min_lrs.append(scheduler.min_lrs[0])
+
+        # Summary logging.
+        trainable, frozen = 0, 0
+        for p in model.parameters():
+            if p.requires_grad:
+                trainable += p.numel()
+            else:
+                frozen += p.numel()
+        logger.info(
+            f"[FT stage @ epoch {stage.at_epoch}] Trainable params: {trainable:,} | Frozen params: {frozen:,}"
         )
 
-
-
-
-
-
+    def freeze_before_training(self, pl_module: LightningModule) -> None:
+        """Hook: Called by Lightning before the first training epoch.
+
+        If a stage is scheduled at epoch 0, we defer its application to the first
+        call of `finetune_function` (when the optimizer is available). Otherwise,
+        we simply log that training begins with a fully trainable model.
+
+        Args:
+            pl_module: The LightningModule being trained.
+        """
+        if any(st.at_epoch == 0 for st in self.stages):
+            logger.info(
+                "Stage scheduled for epoch 0 will be applied at the first finetune_function "
+                "call when the optimizer is available."
+            )
+        else:
+            logger.info("No stage at epoch 0; starting fully trainable by default.")
+
+    def finetune_function(
+        self, pl_module: LightningModule, current_epoch: int, optimizer: Optimizer
+    ) -> None:
+        """Hook: Called by Lightning at each epoch to adjust trainability.
+
+        Applies any stage whose `at_epoch` equals `current_epoch` and that has not
+        yet been applied in this run. Recomputes freeze/unfreeze decisions from
+        scratch for that stage.
+
+        Args:
+            pl_module: The LightningModule being trained.
+            current_epoch: The current epoch index (0-based).
+            optimizer: The optimizer currently used by the trainer.
+        """
+        for st in self.stages:
+            if st.at_epoch == current_epoch and st.at_epoch not in self._applied_epochs:
+                logger.info(
+                    f"Applying multi-stage fine-tuning plan at epoch {current_epoch}"
+                )
+                self._apply_stage(pl_module, optimizer, st)
+                self._applied_epochs.add(st.at_epoch)
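The two callbacks above are configured entirely through their constructors. As an illustration only (not code from the package), they could be wired into a Lightning `Trainer` roughly as follows; the module path `["model", "encoder", 0]` and the `"encoder"` selector strings are assumptions about one particular LightningModule layout and must be adapted to your own model.

```python
# Hypothetical usage sketch; module paths and selector strings are placeholders.
from lightning.pytorch import Trainer

from rslearn.train.callbacks.freeze_unfreeze import (
    FreezeUnfreeze,
    FTStage,
    MultiStageFineTuning,
)

# Simple variant: freeze one submodule, unfreeze it at epoch 5 with lr / 10.
freeze_cb = FreezeUnfreeze(
    module_selector=["model", "encoder", 0],  # assumed attribute path on the LightningModule
    unfreeze_at_epoch=5,
    unfreeze_lr_factor=10,
)

# Staged variant: keep the encoder frozen at first, then unfreeze it at epoch 10
# with a reduced learning rate while scaling the existing (head) param groups down.
staged_cb = MultiStageFineTuning(
    stages=[
        FTStage(at_epoch=0, freeze_selectors=["encoder"], unfreeze_selectors=[]),
        FTStage(
            at_epoch=10,
            freeze_selectors=[],
            unfreeze_selectors=["encoder"],
            unfreeze_lr_factor=10.0,
            scale_existing_groups=0.1,
        ),
    ]
)

trainer = Trainer(max_epochs=50, callbacks=[freeze_cb])  # or callbacks=[staged_cb]
```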
rslearn/train/callbacks/gradients.py

@@ -0,0 +1,129 @@
+"""Gradient logging and surgery callbacks."""
+
+from typing import Any
+
+import torch
+from lightning.pytorch.callbacks import Callback
+from lightning.pytorch.trainer import Trainer
+from torch.nn import Module
+from torch.optim import Optimizer
+
+from rslearn.log_utils import get_logger
+
+logger = get_logger(__name__)
+
+
+class MiniPCGrad(Callback):
+    """PCGrad from https://arxiv.org/abs/2001.06782.
+
+    This is roughly equivalent to PCGrad but uses gradient accumulation to factorize
+    projections, so we can keep gradients orthogonal in O(1) memory instead of O(n).
+    This is still quite slow, requiring an extra copy of parameter gradients in memory.
+    """
+
+    def __init__(
+        self,
+        selectors: list[str],
+        deselectors: list[str] | None = None,
+        only_monitor: bool = False,
+    ) -> None:
+        """Initialize the callback.
+
+        Args:
+            selectors: Prefixes for selecting which parameters to operate on.
+            deselectors: Prefixes for deselecting which parameters to operate on. Applied after selectors.
+            only_monitor: If true, only log gradients, don't clip them.
+        """
+        self.selectors = selectors
+        self.deselectors = deselectors or []
+        self.only_monitor = only_monitor
+        self.prev_grads: dict[str, tuple[torch.Tensor, torch.Tensor]] = {}
+
+    def on_train_batch_start(
+        self, trainer: Trainer, pl_module: Module, batch: Any, batch_idx: int
+    ) -> None:
+        """Save the dataset source each batch."""
+        self.dataset_source = batch[0][0]["dataset_source"]
+        self.batch_size = len(batch[0])
+
+    def on_before_optimizer_step(
+        self, trainer: Trainer, pl_module: Module, optimizer: Optimizer
+    ) -> None:
+        """Reset the previous gradients."""
+        self.prev_grads = {}
+
+    def on_after_backward(self, trainer: Trainer, pl_module: Module) -> None:
+        """Called after every loss.backward(), even under gradient accumulation.
+
+        Receives the accumulated gradients (i.e., accumulated + micro batch gradient).
+
+        Args:
+            trainer: The trainer object.
+            pl_module: The module object.
+        """
+        prev_grad_norms = []
+        micro_grad_norms = []
+        angles = []
+
+        eps = 1e-12  # numerical stability
+
+        for name, param in pl_module.named_parameters():
+            if param.grad is None:
+                continue
+            elif all(selector not in name for selector in self.selectors) or any(
+                deselector in name for deselector in self.deselectors
+            ):
+                continue
+
+            try:
+                prev_grad, prev_grad_norm = self.prev_grads[name]
+            except KeyError:
+                prev_grad = torch.zeros_like(param.grad, device=param.device)
+                prev_grad_norm = torch.tensor(0.0, device=param.device)
+
+            with torch.no_grad():
+                # current accumulated grad = prev_grad + micro_grad
+                micro_grad = param.grad - prev_grad
+                micro_grad_norm = micro_grad.norm()
+
+                micro_grad_norms.append(micro_grad_norm)
+                prev_grad_norms.append(prev_grad_norm)
+
+                # cosine of angle between micro and prev
+                denom = (micro_grad_norm * prev_grad_norm).clamp_min(eps)
+                if prev_grad_norm > 0 and micro_grad_norm > 0:
+                    dot = torch.dot(micro_grad.flatten(), prev_grad.flatten())
+                    cos_theta = dot / denom
+                    angles.append(cos_theta)
+
+                    if not self.only_monitor and dot < 0:
+                        # Remove the component of micro_grad along prev_grad
+                        proj_coeff = dot / (prev_grad_norm**2 + eps)
+                        micro_projection = micro_grad - proj_coeff * prev_grad
+                        # keep accumulated gradient as (prev + projected micro)
+                        param.grad = prev_grad + micro_projection
+                        logger.info(
+                            f"{name} (cos={cos_theta:.4f},dot={dot:.4f},prev_grad_norm={prev_grad_norm:.4f},micro_grad_norm={micro_grad_norm:.4f})"
+                        )
+
+            # store the latest accumulated gradient and its norm
+            self.prev_grads[name] = (param.grad.clone(), param.grad.norm())
+
+        log_prev_grad_norms = (
+            torch.stack(prev_grad_norms).norm()
+            if prev_grad_norms
+            else torch.tensor(0.0)
+        )
+        log_micro_grad_norms = (
+            torch.stack(micro_grad_norms).norm()
+            if micro_grad_norms
+            else torch.tensor(0.0)
+        )
+        log_angles = torch.stack(angles).mean() if angles else torch.tensor(0.0)
+
+        info = {
+            f"grads/{self.dataset_source}_prev_grad_norms": log_prev_grad_norms,
+            f"grads/{self.dataset_source}_micro_grad_norms": log_micro_grad_norms,
+            f"grads/{self.dataset_source}_angles": log_angles,
+        }
+        self.log_dict(info, on_step=True, on_epoch=False, batch_size=self.batch_size)
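The per-parameter surgery in `MiniPCGrad.on_after_backward` is the standard PCGrad rule: when the current micro-batch gradient points against the gradient accumulated so far (negative dot product), its component along the accumulated gradient is removed before the two are summed. A small, self-contained illustration with made-up numbers (not package code):

```python
# Illustration of the PCGrad projection step with toy values.
import torch

prev_grad = torch.tensor([1.0, 0.0])    # gradient accumulated from earlier micro-batches
micro_grad = torch.tensor([-0.5, 1.0])  # gradient of the current micro-batch

dot = torch.dot(micro_grad, prev_grad)
if dot < 0:
    # Conflict: drop the component of micro_grad that points against prev_grad.
    micro_grad = micro_grad - (dot / prev_grad.norm() ** 2) * prev_grad

accumulated = prev_grad + micro_grad
print(accumulated)  # tensor([1., 1.]) -- the conflicting -0.5 component is gone
```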
rslearn/train/callbacks/peft.py

@@ -0,0 +1,116 @@
+"""Parameter-efficient finetuning callbacks."""
+
+import torch
+import torch.nn.functional as F
+from lightning.pytorch import LightningModule
+from lightning.pytorch.callbacks import BaseFinetuning
+from torch.optim.optimizer import Optimizer
+
+
+class SplitProjection(torch.nn.Module):
+    """Split projection weights into trainable and frozen parts.
+
+    This module is used to split the projection weights into trainable and frozen parts.
+    The trainable part is used to compute the output, and the frozen part is used to
+    compute the output without gradients.
+    """
+
+    def __init__(self, dim: int, r: int = 8) -> None:
+        """Initialize the SplitProjection module.
+
+        Args:
+            dim: the dimension of the input and output
+            r: the number of trainable parameters
+        """
+        super().__init__()
+        self.dim = dim
+        self.r = r
+
+        # Register indices as buffers so they move to the correct device automatically
+        indices = torch.randperm(dim)
+        self.register_buffer("trainable_inds", indices[:r])
+        self.register_buffer("frozen_inds", indices[r:])
+
+        # Create parameter modules directly
+        self.trainable_w = torch.nn.Parameter(torch.empty(dim, r), requires_grad=True)
+        self.frozen_w = torch.nn.Parameter(
+            torch.empty(dim, dim - r), requires_grad=False
+        )
+        self.trainable_b = torch.nn.Parameter(torch.empty(r), requires_grad=True)
+        self.frozen_b = torch.nn.Parameter(torch.empty(dim - r), requires_grad=False)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Forward pass of the SplitProjection module.
+
+        Args:
+            x: the input tensor
+
+        Returns:
+            the output tensor
+        """
+        trainable_out = F.linear(x, self.trainable_w, self.trainable_b)
+        frozen_out = F.linear(x, self.frozen_w, self.frozen_b)
+
+        output = torch.zeros(x.shape, device=x.device, dtype=trainable_out.dtype)
+        output[..., self.trainable_inds] = trainable_out  # type: ignore
+        output[..., self.frozen_inds] = frozen_out  # type: ignore
+
+        return output
+
+
+class APLA(BaseFinetuning):
+    """APLA (https://arxiv.org/pdf/2503.11335v2) finetuning callback."""
+
+    def __init__(self, r: int = 8) -> None:
+        """Initialize the APLA finetuning callback.
+
+        Args:
+            r: the number of trainable parameters
+        """
+        super().__init__()
+        self.r = r
+
+    def freeze_before_training(self, pl_module: LightningModule) -> None:
+        """Freeze the model before training.
+
+        Args:
+            pl_module: the LightningModule
+        """
+        print("splitting projection weights by monkeypatching")
+        model = pl_module.model
+        self.freeze(model.encoder[0])
+        n_trainable = 0
+        for layer in model.encoder[0].model.blocks:
+            if hasattr(layer, "attn"):
+                alpa_proj = SplitProjection(layer.attn.proj.weight.shape[0], r=self.r)
+                proj_weight = layer.attn.proj.weight.data.clone()
+                proj_bias = layer.attn.proj.bias.data.clone()
+
+                alpa_proj.trainable_w.data = proj_weight[alpa_proj.trainable_inds, :]
+                alpa_proj.frozen_w.data = proj_weight[alpa_proj.frozen_inds, :]
+
+                alpa_proj.trainable_b.data = proj_bias[alpa_proj.trainable_inds]
+                alpa_proj.frozen_b.data = proj_bias[alpa_proj.frozen_inds]
+
+                alpa_proj.trainable_w.requires_grad = True
+                alpa_proj.trainable_b.requires_grad = True
+                n_trainable += (
+                    alpa_proj.trainable_w.numel() + alpa_proj.trainable_b.numel()
+                )
+
+                layer.attn.proj = alpa_proj
+
+        print(f"n_trainable: {n_trainable / int(1e6)}M")
+
+    def finetune_function(
+        self, pl_module: LightningModule, current_epoch: int, optimizer: Optimizer
+    ) -> None:
+        """Do nothing here.
+
+        Args:
+            pl_module: the LightningModule
+            current_epoch: the current epoch
+            optimizer: the optimizer
+        """
+        # Maybe worth unfreezing down the line?
+        pass
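To make the APLA mechanics concrete, the sketch below (illustrative only, using a toy `nn.Linear` rather than a real attention projection) initializes a `SplitProjection` the same way `APLA.freeze_before_training` does and checks that it reproduces the original layer while exposing only `r` of its rows as trainable.

```python
# Sketch: SplitProjection initialized from a toy Linear layer; not package code.
import torch

from rslearn.train.callbacks.peft import SplitProjection

dim, r = 16, 4
linear = torch.nn.Linear(dim, dim)
split = SplitProjection(dim, r=r)

# Copy the original rows into the trainable / frozen partitions, as APLA does.
split.trainable_w.data = linear.weight.data[split.trainable_inds, :]
split.frozen_w.data = linear.weight.data[split.frozen_inds, :]
split.trainable_b.data = linear.bias.data[split.trainable_inds]
split.frozen_b.data = linear.bias.data[split.frozen_inds]

x = torch.randn(2, dim)
assert torch.allclose(split(x), linear(x), atol=1e-6)

trainable = sum(p.numel() for p in split.parameters() if p.requires_grad)
print(trainable)  # r * dim + r = 68 trainable values vs. dim * dim + dim = 272
```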