rslearn 0.0.1__py3-none-any.whl → 0.0.21__py3-none-any.whl

This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (166)
  1. rslearn/arg_parser.py +31 -0
  2. rslearn/config/__init__.py +6 -12
  3. rslearn/config/dataset.py +520 -401
  4. rslearn/const.py +9 -15
  5. rslearn/data_sources/__init__.py +8 -23
  6. rslearn/data_sources/aws_landsat.py +242 -98
  7. rslearn/data_sources/aws_open_data.py +111 -151
  8. rslearn/data_sources/aws_sentinel1.py +131 -0
  9. rslearn/data_sources/climate_data_store.py +471 -0
  10. rslearn/data_sources/copernicus.py +884 -12
  11. rslearn/data_sources/data_source.py +43 -12
  12. rslearn/data_sources/earthdaily.py +484 -0
  13. rslearn/data_sources/earthdata_srtm.py +282 -0
  14. rslearn/data_sources/eurocrops.py +242 -0
  15. rslearn/data_sources/gcp_public_data.py +578 -222
  16. rslearn/data_sources/google_earth_engine.py +461 -135
  17. rslearn/data_sources/local_files.py +219 -150
  18. rslearn/data_sources/openstreetmap.py +51 -89
  19. rslearn/data_sources/planet.py +24 -60
  20. rslearn/data_sources/planet_basemap.py +275 -0
  21. rslearn/data_sources/planetary_computer.py +798 -0
  22. rslearn/data_sources/usda_cdl.py +195 -0
  23. rslearn/data_sources/usgs_landsat.py +115 -83
  24. rslearn/data_sources/utils.py +249 -61
  25. rslearn/data_sources/vector_source.py +1 -0
  26. rslearn/data_sources/worldcereal.py +449 -0
  27. rslearn/data_sources/worldcover.py +144 -0
  28. rslearn/data_sources/worldpop.py +153 -0
  29. rslearn/data_sources/xyz_tiles.py +150 -107
  30. rslearn/dataset/__init__.py +8 -2
  31. rslearn/dataset/add_windows.py +2 -2
  32. rslearn/dataset/dataset.py +40 -51
  33. rslearn/dataset/handler_summaries.py +131 -0
  34. rslearn/dataset/manage.py +313 -74
  35. rslearn/dataset/materialize.py +431 -107
  36. rslearn/dataset/remap.py +29 -4
  37. rslearn/dataset/storage/__init__.py +1 -0
  38. rslearn/dataset/storage/file.py +202 -0
  39. rslearn/dataset/storage/storage.py +140 -0
  40. rslearn/dataset/window.py +181 -44
  41. rslearn/lightning_cli.py +454 -0
  42. rslearn/log_utils.py +24 -0
  43. rslearn/main.py +384 -181
  44. rslearn/models/anysat.py +215 -0
  45. rslearn/models/attention_pooling.py +177 -0
  46. rslearn/models/clay/clay.py +231 -0
  47. rslearn/models/clay/configs/metadata.yaml +295 -0
  48. rslearn/models/clip.py +68 -0
  49. rslearn/models/component.py +111 -0
  50. rslearn/models/concatenate_features.py +103 -0
  51. rslearn/models/conv.py +63 -0
  52. rslearn/models/croma.py +306 -0
  53. rslearn/models/detr/__init__.py +5 -0
  54. rslearn/models/detr/box_ops.py +103 -0
  55. rslearn/models/detr/detr.py +504 -0
  56. rslearn/models/detr/matcher.py +107 -0
  57. rslearn/models/detr/position_encoding.py +114 -0
  58. rslearn/models/detr/transformer.py +429 -0
  59. rslearn/models/detr/util.py +24 -0
  60. rslearn/models/dinov3.py +177 -0
  61. rslearn/models/faster_rcnn.py +30 -28
  62. rslearn/models/feature_center_crop.py +53 -0
  63. rslearn/models/fpn.py +19 -8
  64. rslearn/models/galileo/__init__.py +5 -0
  65. rslearn/models/galileo/galileo.py +595 -0
  66. rslearn/models/galileo/single_file_galileo.py +1678 -0
  67. rslearn/models/module_wrapper.py +65 -0
  68. rslearn/models/molmo.py +69 -0
  69. rslearn/models/multitask.py +384 -28
  70. rslearn/models/olmoearth_pretrain/__init__.py +1 -0
  71. rslearn/models/olmoearth_pretrain/model.py +421 -0
  72. rslearn/models/olmoearth_pretrain/norm.py +86 -0
  73. rslearn/models/panopticon.py +170 -0
  74. rslearn/models/panopticon_data/sensors/drone.yaml +32 -0
  75. rslearn/models/panopticon_data/sensors/enmap.yaml +904 -0
  76. rslearn/models/panopticon_data/sensors/goes.yaml +9 -0
  77. rslearn/models/panopticon_data/sensors/himawari.yaml +9 -0
  78. rslearn/models/panopticon_data/sensors/intuition.yaml +606 -0
  79. rslearn/models/panopticon_data/sensors/landsat8.yaml +84 -0
  80. rslearn/models/panopticon_data/sensors/modis_terra.yaml +99 -0
  81. rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +34 -0
  82. rslearn/models/panopticon_data/sensors/sentinel1.yaml +85 -0
  83. rslearn/models/panopticon_data/sensors/sentinel2.yaml +97 -0
  84. rslearn/models/panopticon_data/sensors/superdove.yaml +60 -0
  85. rslearn/models/panopticon_data/sensors/wv23.yaml +63 -0
  86. rslearn/models/pick_features.py +17 -10
  87. rslearn/models/pooling_decoder.py +60 -7
  88. rslearn/models/presto/__init__.py +5 -0
  89. rslearn/models/presto/presto.py +297 -0
  90. rslearn/models/presto/single_file_presto.py +926 -0
  91. rslearn/models/prithvi.py +1147 -0
  92. rslearn/models/resize_features.py +59 -0
  93. rslearn/models/sam2_enc.py +13 -9
  94. rslearn/models/satlaspretrain.py +38 -18
  95. rslearn/models/simple_time_series.py +188 -77
  96. rslearn/models/singletask.py +24 -13
  97. rslearn/models/ssl4eo_s12.py +40 -30
  98. rslearn/models/swin.py +44 -32
  99. rslearn/models/task_embedding.py +250 -0
  100. rslearn/models/terramind.py +256 -0
  101. rslearn/models/trunk.py +139 -0
  102. rslearn/models/unet.py +68 -22
  103. rslearn/models/upsample.py +48 -0
  104. rslearn/models/use_croma.py +508 -0
  105. rslearn/template_params.py +26 -0
  106. rslearn/tile_stores/__init__.py +41 -18
  107. rslearn/tile_stores/default.py +409 -0
  108. rslearn/tile_stores/tile_store.py +236 -132
  109. rslearn/train/all_patches_dataset.py +530 -0
  110. rslearn/train/callbacks/adapters.py +53 -0
  111. rslearn/train/callbacks/freeze_unfreeze.py +348 -17
  112. rslearn/train/callbacks/gradients.py +129 -0
  113. rslearn/train/callbacks/peft.py +116 -0
  114. rslearn/train/data_module.py +444 -20
  115. rslearn/train/dataset.py +588 -235
  116. rslearn/train/lightning_module.py +192 -62
  117. rslearn/train/model_context.py +88 -0
  118. rslearn/train/optimizer.py +31 -0
  119. rslearn/train/prediction_writer.py +319 -84
  120. rslearn/train/scheduler.py +92 -0
  121. rslearn/train/tasks/classification.py +55 -28
  122. rslearn/train/tasks/detection.py +132 -76
  123. rslearn/train/tasks/embedding.py +120 -0
  124. rslearn/train/tasks/multi_task.py +28 -14
  125. rslearn/train/tasks/per_pixel_regression.py +291 -0
  126. rslearn/train/tasks/regression.py +161 -44
  127. rslearn/train/tasks/segmentation.py +428 -53
  128. rslearn/train/tasks/task.py +6 -5
  129. rslearn/train/transforms/__init__.py +1 -1
  130. rslearn/train/transforms/concatenate.py +54 -10
  131. rslearn/train/transforms/crop.py +29 -11
  132. rslearn/train/transforms/flip.py +18 -6
  133. rslearn/train/transforms/mask.py +78 -0
  134. rslearn/train/transforms/normalize.py +101 -17
  135. rslearn/train/transforms/pad.py +19 -7
  136. rslearn/train/transforms/resize.py +83 -0
  137. rslearn/train/transforms/select_bands.py +76 -0
  138. rslearn/train/transforms/sentinel1.py +75 -0
  139. rslearn/train/transforms/transform.py +89 -70
  140. rslearn/utils/__init__.py +2 -6
  141. rslearn/utils/array.py +8 -6
  142. rslearn/utils/feature.py +2 -2
  143. rslearn/utils/fsspec.py +90 -1
  144. rslearn/utils/geometry.py +347 -7
  145. rslearn/utils/get_utm_ups_crs.py +2 -3
  146. rslearn/utils/grid_index.py +5 -5
  147. rslearn/utils/jsonargparse.py +178 -0
  148. rslearn/utils/mp.py +4 -3
  149. rslearn/utils/raster_format.py +268 -116
  150. rslearn/utils/rtree_index.py +64 -17
  151. rslearn/utils/sqlite_index.py +7 -1
  152. rslearn/utils/vector_format.py +252 -97
  153. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/METADATA +532 -283
  154. rslearn-0.0.21.dist-info/RECORD +167 -0
  155. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/WHEEL +1 -1
  156. rslearn-0.0.21.dist-info/licenses/NOTICE +115 -0
  157. rslearn/data_sources/raster_source.py +0 -309
  158. rslearn/models/registry.py +0 -5
  159. rslearn/tile_stores/file.py +0 -242
  160. rslearn/utils/mgrs.py +0 -24
  161. rslearn/utils/utils.py +0 -22
  162. rslearn-0.0.1.dist-info/RECORD +0 -88
  163. /rslearn/{data_sources/geotiff.py → py.typed} +0 -0
  164. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/entry_points.txt +0 -0
  165. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info/licenses}/LICENSE +0 -0
  166. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/top_level.txt +0 -0
rslearn/train/callbacks/freeze_unfreeze.py

@@ -1,10 +1,20 @@
  """FreezeUnfreeze callback."""

+ from __future__ import annotations
+
+ from collections.abc import Iterable, Sequence
+ from dataclasses import dataclass
+
  import torch
  from lightning.pytorch import LightningModule
  from lightning.pytorch.callbacks import BaseFinetuning
+ from torch.optim.lr_scheduler import ReduceLROnPlateau
  from torch.optim.optimizer import Optimizer

+ from rslearn.log_utils import get_logger
+
+ logger = get_logger(__name__)
+

  class FreezeUnfreeze(BaseFinetuning):
      """Freezes a module and optionally unfreezes it after a number of epochs."""
@@ -14,7 +24,7 @@ class FreezeUnfreeze(BaseFinetuning):
          module_selector: list[str | int],
          unfreeze_at_epoch: int | None = None,
          unfreeze_lr_factor: float = 1,
-     ):
+     ) -> None:
          """Creates a new FreezeUnfreeze.

          Args:
@@ -30,6 +40,8 @@ class FreezeUnfreeze(BaseFinetuning):
          self.module_selector = module_selector
          self.unfreeze_at_epoch = unfreeze_at_epoch
          self.unfreeze_lr_factor = unfreeze_lr_factor
+         if unfreeze_at_epoch == 0:
+             raise ValueError("unfreeze_at_epoch cannot be 0")

      def _get_target_module(self, pl_module: LightningModule) -> torch.nn.Module:
          target_module = pl_module
@@ -40,18 +52,18 @@ class FreezeUnfreeze(BaseFinetuning):
              target_module = getattr(target_module, k)
          return target_module

-     def freeze_before_training(self, pl_module: LightningModule):
+     def freeze_before_training(self, pl_module: LightningModule) -> None:
          """Freeze the model at the beginning of training.

          Args:
              pl_module: the LightningModule.
          """
-         print(f"freezing model at {self.module_selector}")
+         logger.info(f"freezing model at {self.module_selector}")
          self.freeze(self._get_target_module(pl_module))

      def finetune_function(
          self, pl_module: LightningModule, current_epoch: int, optimizer: Optimizer
-     ):
+     ) -> None:
          """Check whether we should unfreeze the model on each epoch.

          Args:
@@ -61,19 +73,338 @@ class FreezeUnfreeze(BaseFinetuning):
          """
          if self.unfreeze_at_epoch is None:
              return
-         if current_epoch != self.unfreeze_at_epoch:
-             return
-         print(
-             f"unfreezing model at {self.module_selector} since we are on epoch {current_epoch}"
+         elif current_epoch == self.unfreeze_at_epoch:
+             logger.info(
+                 f"unfreezing model at {self.module_selector} since we are on epoch {current_epoch}"
+             )
+             self.unfreeze_and_add_param_group(
+                 modules=self._get_target_module(pl_module),
+                 optimizer=optimizer,
+                 initial_denom_lr=self.unfreeze_lr_factor,
+             )
+             if "scheduler" in pl_module.schedulers:
+                 scheduler = pl_module.schedulers["scheduler"]
+                 if isinstance(scheduler, ReduceLROnPlateau):
+                     while len(scheduler.min_lrs) < len(optimizer.param_groups):
+                         logger.info(
+                             "appending to ReduceLROnPlateau scheduler min_lrs for unfreeze"
+                         )
+                         scheduler.min_lrs.append(scheduler.min_lrs[0])
+         elif current_epoch > self.unfreeze_at_epoch:
+             # always do this because overhead is minimal, and it allows restoring
+             # from a checkpoint (resuming a run) without messing up unfreezing
+             BaseFinetuning.make_trainable(self._get_target_module(pl_module))
+
+
+ @dataclass
+ class FTStage:
+     """Specification for a single fine-tuning stage.
+
+     Each stage is activated when the trainer reaches a specific epoch (`at_epoch`).
+     Within that stage, modules whose **qualified name** (from `named_modules()`)
+     matches any substring in `freeze_selectors` will be frozen, except those whose
+     name matches any substring in `unfreeze_selectors`, which are forced trainable.
+
+     freeze_selectors does not carry over to other stages. That is, if you freeze module
+     A for stage 1, it will not be frozen for stage 2 unless specified again in stage 2.
+     All stages independently update trainability of all modules specified or unspecified.
+
+     Args:
+         at_epoch: Epoch index at which to apply this stage (0-based).
+         freeze_selectors: Substrings; any module name containing any of these will
+             be frozen in this stage (unless also matched by `unfreeze_selectors`).
+         unfreeze_selectors: Substrings; any module name containing any of these
+             will be **unfrozen** (trainable) in this stage, overriding freezes.
+         unfreeze_lr_factor: When parameters become trainable and are **not yet**
+             part of the optimizer, a new param group is added with learning rate
+             `base_lr / unfreeze_lr_factor`. Use 1.0 to keep the base learning rate.
+         scale_existing_groups: If provided and not 1.0, multiply the learning rate
+             of **all existing optimizer param groups** by this factor at the moment
+             this stage is applied. Use this to calm down previously-trainable
+             parts (e.g., the head) when unfreezing deeper layers.
+             Set to ``None`` to leave existing groups unchanged.
+     """
+
+     at_epoch: int
+     freeze_selectors: Sequence[str]
+     unfreeze_selectors: Sequence[str]
+     unfreeze_lr_factor: float = 1.0
+     scale_existing_groups: float | None = None
+
+
+ class MultiStageFineTuning(BaseFinetuning):
+     """Multi-stage fine-tuning with flexible name-based selection.
+
+     Behavior per stage:
+     1) Start from a **fully trainable** baseline.
+     2) Optionally **scale existing optimizer groups** via `scale_existing_groups`.
+     3) **Freeze** modules matching any `freeze_selectors`.
+     4) **Unfreeze** modules matching any `unfreeze_selectors` (overrides step 3).
+     5) For newly trainable parameters **not yet** in the optimizer, add a new
+        param group using `unfreeze_lr_factor` (lr = base_lr / factor).
+
+     Stages are applied exactly once at their `at_epoch`. The plan is recomputed
+     from scratch at each stage to keep behavior predictable on resume.
+     """
+
+     def __init__(self, stages: list[FTStage]) -> None:
+         """Multi-stage fine-tuning with flexible name-based selection.
+
+         Args:
+             stages: A sequence of stage specifications.
+
+         Raises:
+             ValueError: If two stages specify the same `at_epoch`.
+         """
+         super().__init__()
+         self.stages = stages
+
+         # Validate uniqueness of epochs and sort stages.
+         seen: set[int] = set()
+         for st in self.stages:
+             if st.at_epoch in seen:
+                 raise ValueError(f"Duplicate at_epoch in stages: {st.at_epoch}")
+             if st.scale_existing_groups is not None and st.scale_existing_groups <= 0.0:
+                 raise ValueError("scale_existing_groups, if set, must be > 0.")
+             seen.add(st.at_epoch)
+         self.stages.sort(key=lambda x: x.at_epoch)
+
+         self._applied_epochs: set[int] = set()
+
+     @staticmethod
+     def _freeze_unfreeze(mod: torch.nn.Module, freeze: bool) -> None:
+         """Freeze or unfreeze all parameters of a module without going through Lightning's flatten logic.
+
+         This is a workaround to avoid infinite recursion on ModuleDicts.
+
+         Args:
+             mod: The module to freeze.
+             freeze: Whether to freeze the module.
+         """
+         for p in mod.parameters(recurse=True):
+             p.requires_grad = not freeze
+
+     @staticmethod
+     def _names_matching(names: Iterable[str], selectors: Sequence[str]) -> set[str]:
+         """Return the subset of `names` that contains any of the given selectors.
+
+         Matching is done via simple substring checks (`sel in name`).
+
+         Args:
+             names: Iterable of qualified module names (e.g., from `named_modules()`).
+             selectors: Substrings to match against each name. Empty strings are ignored.
+
+         Returns:
+             A set of names from `names` that match at least one selector.
+         """
+         if not selectors:
+             return set()
+         sels: list[str] = [s for s in selectors if s]
+         out: set[str] = set()
+         for n in names:
+             if any(sel in n for sel in sels):
+                 out.add(n)
+         return out
+
+     @staticmethod
+     def _modules_by_names(
+         root: torch.nn.Module, wanted: set[str]
+     ) -> list[torch.nn.Module]:
+         """Map qualified names to module objects.
+
+         Args:
+             root: The root module (e.g., your LightningModule).
+             wanted: Qualified names of submodules to retrieve.
+
+         Returns:
+             A list of modules corresponding to the given names that exist under `root`.
+         """
+         if not wanted:
+             return []
+         name_to_module: dict[str, torch.nn.Module] = dict(root.named_modules())
+         return [name_to_module[n] for n in wanted if n in name_to_module]
+
+     @staticmethod
+     def _existing_param_ids(optimizer: Optimizer) -> set[int]:
+         """Collect ids of all parameters already tracked by the optimizer.
+
+         Args:
+             optimizer: The optimizer to inspect.
+
+         Returns:
+             A set of parameter ids already tracked by the optimizer.
+         """
+         return {id(p) for g in optimizer.param_groups for p in g["params"]}
+
+     @staticmethod
+     def _iter_module_params(modules: list[torch.nn.Module]) -> list[torch.nn.Parameter]:
+         """Flatten parameters from a list of modules (no duplicates, trainable first).
+
+         Args:
+             modules: A list of modules to inspect.
+
+         Returns:
+             A list of parameters from the modules, in order of appearance.
+         """
+         seen: set[int] = set()
+         ordered: list[torch.nn.Parameter] = []
+         for m in modules:
+             for p in m.parameters():
+                 if id(p) not in seen:
+                     seen.add(id(p))
+                     ordered.append(p)
+         return ordered
+
+     def _apply_stage(
+         self, pl_module: LightningModule, optimizer: Optimizer, stage: FTStage
+     ) -> None:
+         """Apply a single fine-tuning stage to `pl_module` and `optimizer`.
+
+         Order of operations:
+         1) Make everything trainable (baseline).
+         2) If `scale_existing_groups` is set, multiply LR of **existing** optimizer
+            groups by this factor (and update ReduceLROnPlateau `min_lrs` if present).
+         3) Freeze modules matched by `freeze_selectors` minus `unfreeze_selectors`.
+         4) Ensure modules matched by `unfreeze_selectors` are trainable.
+         5) Add new optimizer param groups for newly-trainable modules with LR
+            scaled by `unfreeze_lr_factor`.
+
+         Args:
+             pl_module: The LightningModule being trained.
+             optimizer: The optimizer currently used by the trainer.
+             stage: The stage specification to apply at the current epoch.
+
+         Returns:
+             None.
+         """
+         model: torch.nn.Module = pl_module
+         all_names: list[str] = [n for n, _ in model.named_modules()]
+
+         freeze_names: set[str] = self._names_matching(all_names, stage.freeze_selectors)
+         unfreeze_names: set[str] = self._names_matching(
+             all_names, stage.unfreeze_selectors
+         )
+
+         # 1) Baseline: everything trainable.
+         self._freeze_unfreeze(model, freeze=False)
+
+         # 2) Optionally scale existing optimizer groups (e.g., calm down the head).
+         if (
+             stage.scale_existing_groups is not None
+             and stage.scale_existing_groups != 1.0
+         ):
+             factor: float = stage.scale_existing_groups
+             for g in optimizer.param_groups:
+                 old_lr = float(g.get("lr", 0.0))
+                 g["lr"] = old_lr * factor
+             # Keep ReduceLROnPlateau bounds consistent if present.
+             if hasattr(pl_module, "schedulers") and "scheduler" in getattr(
+                 pl_module, "schedulers", {}
+             ):
+                 scheduler = pl_module.schedulers["scheduler"]
+                 if isinstance(scheduler, ReduceLROnPlateau):
+                     scheduler.min_lrs = [float(m) * factor for m in scheduler.min_lrs]
+
+         # 3) Freeze matched, except those explicitly unfreezed.
+         to_freeze: set[str] = freeze_names - unfreeze_names
+         freeze_modules: list[torch.nn.Module] = self._modules_by_names(model, to_freeze)
+         if freeze_modules:
+             to_display = sorted(list(to_freeze))
+             logger.info(
+                 f"[FT stage @ epoch {stage.at_epoch}] Freezing {len(freeze_modules)} modules "
+                 f"(matched: {to_display[:2] + to_display[-2:]}{'...' if len(to_freeze) > 4 else ''})"
+             )
+             for m in freeze_modules:
+                 self._freeze_unfreeze(m, freeze=True)
+
+         # 4) Ensure explicitly unfreezed modules are trainable.
+         unfreeze_modules: list[torch.nn.Module] = self._modules_by_names(
+             model, unfreeze_names
          )
-         self.unfreeze_and_add_param_group(
-             modules=self._get_target_module(pl_module),
-             optimizer=optimizer,
-             initial_denom_lr=self.unfreeze_lr_factor,
+         if unfreeze_modules:
+             to_display = sorted(list(unfreeze_names))
+             logger.info(
+                 f"[FT stage @ epoch {stage.at_epoch}] Unfreezing {len(unfreeze_modules)} modules "
+                 f"(matched: {to_display[:2] + to_display[-2:]}{'...' if len(unfreeze_names) > 4 else ''})"
+             )
+             for m in unfreeze_modules:
+                 self._freeze_unfreeze(m, freeze=False)
+
+         # 5) Add *newly-trainable* params only (no duplicates)
+         denom: float = (
+             stage.unfreeze_lr_factor if stage.unfreeze_lr_factor != 1.0 else 1.0
+         )
+         all_params = self._iter_module_params(unfreeze_modules)
+         already = self._existing_param_ids(optimizer)
+         new_params = [
+             p for p in all_params if p.requires_grad and id(p) not in already
+         ]
+
+         if new_params:
+             # Use current "base" lr (after any scale_existing_groups) as the reference
+             base_lr = float(optimizer.param_groups[0].get("lr", 0.0))
+             group_lr = base_lr / denom if denom != 0 else base_lr
+             optimizer.add_param_group({"params": new_params, "lr": group_lr})
+
+             # Extend ReduceLROnPlateau.min_lrs to match param group count
+             if hasattr(pl_module, "schedulers") and "scheduler" in getattr(
+                 pl_module, "schedulers", {}
+             ):
+                 scheduler = pl_module.schedulers["scheduler"]
+                 if isinstance(scheduler, ReduceLROnPlateau):
+                     while len(scheduler.min_lrs) < len(optimizer.param_groups):
+                         logger.info(
+                             "Extending ReduceLROnPlateau.min_lrs for new param group"
+                         )
+                         scheduler.min_lrs.append(scheduler.min_lrs[0])
+
+         # Summary logging.
+         trainable, frozen = 0, 0
+         for p in model.parameters():
+             if p.requires_grad:
+                 trainable += p.numel()
+             else:
+                 frozen += p.numel()
+         logger.info(
+             f"[FT stage @ epoch {stage.at_epoch}] Trainable params: {trainable:,} | Frozen params: {frozen:,}"
          )

-         if "plateau" in pl_module.schedulers:
-             scheduler = pl_module.schedulers["plateau"]
-             while len(scheduler.min_lrs) < len(optimizer.param_groups):
-                 print("appending to plateau scheduler min_lrs")
-                 scheduler.min_lrs.append(scheduler.min_lrs[0])
+     def freeze_before_training(self, pl_module: LightningModule) -> None:
+         """Hook: Called by Lightning before the first training epoch.
+
+         If a stage is scheduled at epoch 0, we defer its application to the first
+         call of `finetune_function` (when the optimizer is available). Otherwise,
+         we simply log that training begins with a fully trainable model.
+
+         Args:
+             pl_module: The LightningModule being trained.
+         """
+         if any(st.at_epoch == 0 for st in self.stages):
+             logger.info(
+                 "Stage scheduled for epoch 0 will be applied at the first finetune_function "
+                 "call when the optimizer is available."
+             )
+         else:
+             logger.info("No stage at epoch 0; starting fully trainable by default.")
+
+     def finetune_function(
+         self, pl_module: LightningModule, current_epoch: int, optimizer: Optimizer
+     ) -> None:
+         """Hook: Called by Lightning at each epoch to adjust trainability.
+
+         Applies any stage whose `at_epoch` equals `current_epoch` and that has not
+         yet been applied in this run. Recomputes freeze/unfreeze decisions from
+         scratch for that stage.
+
+         Args:
+             pl_module: The LightningModule being trained.
+             current_epoch: The current epoch index (0-based).
+             optimizer: The optimizer currently used by the trainer.
+         """
+         for st in self.stages:
+             if st.at_epoch == current_epoch and st.at_epoch not in self._applied_epochs:
+                 logger.info(
+                     f"Applying multi-stage fine-tuning plan at epoch {current_epoch}"
+                 )
+                 self._apply_stage(pl_module, optimizer, st)
+                 self._applied_epochs.add(st.at_epoch)
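
For orientation, the sketch below (not part of the diff) shows one way the new MultiStageFineTuning callback above could be wired into a Lightning Trainer. The import path follows the file list entry rslearn/train/callbacks/freeze_unfreeze.py; the "encoder" selector string, the epochs, and the factors are hypothetical and would have to match names from your model's named_modules().

    from lightning.pytorch import Trainer

    from rslearn.train.callbacks.freeze_unfreeze import FTStage, MultiStageFineTuning

    # Hypothetical plan: keep the encoder frozen at first, then unfreeze it at epoch 5
    # with a reduced learning rate while halving the LR of the already-trainable groups.
    stages = [
        FTStage(at_epoch=0, freeze_selectors=["encoder"], unfreeze_selectors=[]),
        FTStage(
            at_epoch=5,
            freeze_selectors=[],
            unfreeze_selectors=["encoder"],
            unfreeze_lr_factor=10.0,  # new param groups start at base_lr / 10
            scale_existing_groups=0.5,  # halve the LR of existing param groups
        ),
    ]

    trainer = Trainer(max_epochs=20, callbacks=[MultiStageFineTuning(stages=stages)])
    # trainer.fit(lightning_module, datamodule=data_module)  # model and data not shown
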
rslearn/train/callbacks/gradients.py

@@ -0,0 +1,129 @@
+ """Gradient logging and surgery callbacks."""
+
+ from typing import Any
+
+ import torch
+ from lightning.pytorch.callbacks import Callback
+ from lightning.pytorch.trainer import Trainer
+ from torch.nn import Module
+ from torch.optim import Optimizer
+
+ from rslearn.log_utils import get_logger
+
+ logger = get_logger(__name__)
+
+
+ class MiniPCGrad(Callback):
+     """PCGrad from https://arxiv.org/abs/2001.06782.
+
+     This is roughly equivalent to PCGrad but uses gradient accumulation to factorize
+     projections, so we can keep gradients orthogonal in O(1) memory instead of O(n).
+     This is still quite slow, requiring an extra copy of parameter gradients in memory.
+     """
+
+     def __init__(
+         self,
+         selectors: list[str],
+         deselectors: list[str] | None = None,
+         only_monitor: bool = False,
+     ) -> None:
+         """Initialize the callback.
+
+         Args:
+             selectors: Prefixes for selecting which parameters to operate on.
+             deselectors: Prefixes for deselecting which parameters to operate on. Applied after selectors.
+             only_monitor: If true, only log gradients, don't clip them.
+         """
+         self.selectors = selectors
+         self.deselectors = deselectors or []
+         self.only_monitor = only_monitor
+         self.prev_grads: dict[str, tuple[torch.Tensor, torch.Tensor]] = {}
+
+     def on_train_batch_start(
+         self, trainer: Trainer, pl_module: Module, batch: Any, batch_idx: int
+     ) -> None:
+         """Save the dataset source each batch."""
+         self.dataset_source = batch[0][0]["dataset_source"]
+         self.batch_size = len(batch[0])
+
+     def on_before_optimizer_step(
+         self, trainer: Trainer, pl_module: Module, optimizer: Optimizer
+     ) -> None:
+         """Reset the previous gradients."""
+         self.prev_grads = {}
+
+     def on_after_backward(self, trainer: Trainer, pl_module: Module) -> None:
+         """Called after every loss.backward(), even under gradient accumulation.
+
+         Receives the accumulated gradients (i.e., accumulated + micro batch gradient).
+
+         Args:
+             trainer: The trainer object.
+             pl_module: The module object.
+         """
+         prev_grad_norms = []
+         micro_grad_norms = []
+         angles = []
+
+         eps = 1e-12  # numerical stability
+
+         for name, param in pl_module.named_parameters():
+             if param.grad is None:
+                 continue
+             elif all(selector not in name for selector in self.selectors) or any(
+                 deselector in name for deselector in self.deselectors
+             ):
+                 continue
+
+             try:
+                 prev_grad, prev_grad_norm = self.prev_grads[name]
+             except KeyError:
+                 prev_grad = torch.zeros_like(param.grad, device=param.device)
+                 prev_grad_norm = torch.tensor(0.0, device=param.device)
+
+             with torch.no_grad():
+                 # current accumulated grad = prev_grad + micro_grad
+                 micro_grad = param.grad - prev_grad
+                 micro_grad_norm = micro_grad.norm()
+
+                 micro_grad_norms.append(micro_grad_norm)
+                 prev_grad_norms.append(prev_grad_norm)
+
+                 # cosine of angle between micro and prev
+                 denom = (micro_grad_norm * prev_grad_norm).clamp_min(eps)
+                 if prev_grad_norm > 0 and micro_grad_norm > 0:
+                     dot = torch.dot(micro_grad.flatten(), prev_grad.flatten())
+                     cos_theta = dot / denom
+                     angles.append(cos_theta)
+
+                     if not self.only_monitor and dot < 0:
+                         # Remove the component of micro_grad along prev_grad
+                         proj_coeff = dot / (prev_grad_norm**2 + eps)
+                         micro_projection = micro_grad - proj_coeff * prev_grad
+                         # keep accumulated gradient as (prev + projected micro)
+                         param.grad = prev_grad + micro_projection
+                         logger.info(
+                             f"{name} (cos={cos_theta:.4f},dot={dot:.4f},prev_grad_norm={prev_grad_norm:.4f},micro_grad_norm={micro_grad_norm:.4f})"
+                         )
+
+             # store the latest accumulated gradient and its norm
+             self.prev_grads[name] = (param.grad.clone(), param.grad.norm())
+
+         log_prev_grad_norms = (
+             torch.stack(prev_grad_norms).norm()
+             if prev_grad_norms
+             else torch.tensor(0.0)
+         )
+         log_micro_grad_norms = (
+             torch.stack(micro_grad_norms).norm()
+             if micro_grad_norms
+             else torch.tensor(0.0)
+         )
+         log_angles = torch.stack(angles).mean() if angles else torch.tensor(0.0)
+
+         info = {
+             f"grads/{self.dataset_source}_prev_grad_norms": log_prev_grad_norms,
+             f"grads/{self.dataset_source}_micro_grad_norms": log_micro_grad_norms,
+             f"grads/{self.dataset_source}_angles": log_angles,
+         }
+         self.log_dict(info, on_step=True, on_epoch=False, batch_size=self.batch_size)
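
As a sanity check on the projection step in MiniPCGrad above: when a micro-batch gradient points against the accumulated gradient (negative dot product), only the conflicting component is removed before accumulation. The toy tensors below are made up for illustration and are not rslearn API.

    import torch

    prev_grad = torch.tensor([1.0, 0.0])    # accumulated gradient so far
    micro_grad = torch.tensor([-1.0, 1.0])  # new micro-batch gradient

    dot = torch.dot(micro_grad, prev_grad)                   # -1.0 < 0: the gradients conflict
    proj_coeff = dot / (prev_grad.norm() ** 2)               # -1.0
    micro_projection = micro_grad - proj_coeff * prev_grad   # tensor([0., 1.])

    # The accumulated gradient keeps only the non-conflicting component of micro_grad.
    accumulated = prev_grad + micro_projection               # tensor([1., 1.])
    assert torch.dot(micro_projection, prev_grad).abs() < 1e-6
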
rslearn/train/callbacks/peft.py

@@ -0,0 +1,116 @@
+ """Parameter-efficient finetuning callbacks."""
+
+ import torch
+ import torch.nn.functional as F
+ from lightning.pytorch import LightningModule
+ from lightning.pytorch.callbacks import BaseFinetuning
+ from torch.optim.optimizer import Optimizer
+
+
+ class SplitProjection(torch.nn.Module):
+     """Split projection weights into trainable and frozen parts.
+
+     This module is used to split the projection weights into trainable and frozen parts.
+     The trainable part is used to compute the output, and the frozen part is used to
+     compute the output without gradients.
+     """
+
+     def __init__(self, dim: int, r: int = 8) -> None:
+         """Initialize the SplitProjection module.
+
+         Args:
+             dim: the dimension of the input and output
+             r: the number of trainable parameters
+         """
+         super().__init__()
+         self.dim = dim
+         self.r = r
+
+         # Register indices as buffers so they move to the correct device automatically
+         indices = torch.randperm(dim)
+         self.register_buffer("trainable_inds", indices[:r])
+         self.register_buffer("frozen_inds", indices[r:])
+
+         # Create parameter modules directly
+         self.trainable_w = torch.nn.Parameter(torch.empty(dim, r), requires_grad=True)
+         self.frozen_w = torch.nn.Parameter(
+             torch.empty(dim, dim - r), requires_grad=False
+         )
+         self.trainable_b = torch.nn.Parameter(torch.empty(r), requires_grad=True)
+         self.frozen_b = torch.nn.Parameter(torch.empty(dim - r), requires_grad=False)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """Forward pass of the SplitProjection module.
+
+         Args:
+             x: the input tensor
+
+         Returns:
+             the output tensor
+         """
+         trainable_out = F.linear(x, self.trainable_w, self.trainable_b)
+         frozen_out = F.linear(x, self.frozen_w, self.frozen_b)
+
+         output = torch.zeros(x.shape, device=x.device, dtype=trainable_out.dtype)
+         output[..., self.trainable_inds] = trainable_out  # type: ignore
+         output[..., self.frozen_inds] = frozen_out  # type: ignore
+
+         return output
+
+
+ class APLA(BaseFinetuning):
+     """APLA (https://arxiv.org/pdf/2503.11335v2) finetuning callback."""
+
+     def __init__(self, r: int = 8) -> None:
+         """Initialize the APLA finetuning callback.
+
+         Args:
+             r: the number of trainable parameters
+         """
+         super().__init__()
+         self.r = r
+
+     def freeze_before_training(self, pl_module: LightningModule) -> None:
+         """Freeze the model before training.
+
+         Args:
+             pl_module: the LightningModule
+         """
+         print("splitting projection weights by monkeypatching")
+         model = pl_module.model
+         self.freeze(model.encoder[0])
+         n_trainable = 0
+         for layer in model.encoder[0].model.blocks:
+             if hasattr(layer, "attn"):
+                 alpa_proj = SplitProjection(layer.attn.proj.weight.shape[0], r=self.r)
+                 proj_weight = layer.attn.proj.weight.data.clone()
+                 proj_bias = layer.attn.proj.bias.data.clone()
+
+                 alpa_proj.trainable_w.data = proj_weight[alpa_proj.trainable_inds, :]
+                 alpa_proj.frozen_w.data = proj_weight[alpa_proj.frozen_inds, :]
+
+                 alpa_proj.trainable_b.data = proj_bias[alpa_proj.trainable_inds]
+                 alpa_proj.frozen_b.data = proj_bias[alpa_proj.frozen_inds]
+
+                 alpa_proj.trainable_w.requires_grad = True
+                 alpa_proj.trainable_b.requires_grad = True
+                 n_trainable += (
+                     alpa_proj.trainable_w.numel() + alpa_proj.trainable_b.numel()
+                 )
+
+                 layer.attn.proj = alpa_proj
+
+         print(f"n_trainable: {n_trainable / int(1e6)}M")
+
+     def finetune_function(
+         self, pl_module: LightningModule, current_epoch: int, optimizer: Optimizer
+     ) -> None:
+         """Do nothing here.
+
+         Args:
+             pl_module: the LightningModule
+             current_epoch: the current epoch
+             optimizer: the optimizer
+         """
+         # Maybe worth unfreezing down the line?
+         pass
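
To make the weight splitting concrete, here is a small standalone check (not part of the diff) that mirrors what APLA.freeze_before_training does to each attention projection: copying rows of an existing square nn.Linear into a SplitProjection reproduces the original layer's output (up to floating point rounding) while leaving only r of the dim output rows trainable. The import path is taken from the file list entry rslearn/train/callbacks/peft.py; the dimensions are arbitrary.

    import torch

    from rslearn.train.callbacks.peft import SplitProjection

    dim, r = 16, 4
    linear = torch.nn.Linear(dim, dim)
    split = SplitProjection(dim, r=r)

    # Same row-wise copy that APLA performs for each layer.attn.proj.
    weight, bias = linear.weight.data.clone(), linear.bias.data.clone()
    split.trainable_w.data = weight[split.trainable_inds, :]
    split.frozen_w.data = weight[split.frozen_inds, :]
    split.trainable_b.data = bias[split.trainable_inds]
    split.frozen_b.data = bias[split.frozen_inds]

    x = torch.randn(2, dim)
    assert torch.allclose(split(x), linear(x), atol=1e-6)  # same output; only r rows stay trainable
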