nnInteractive 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nnInteractive/__init__.py +3 -0
- nnInteractive/inference/__init__.py +0 -0
- nnInteractive/inference/cvpr2025_challenge_baseline/__init__.py +0 -0
- nnInteractive/inference/cvpr2025_challenge_baseline/predict.py +173 -0
- nnInteractive/inference/inference_session.py +1400 -0
- nnInteractive/interaction/__init__.py +0 -0
- nnInteractive/interaction/point.py +166 -0
- nnInteractive/supervoxel/setup.py +4 -0
- nnInteractive/supervoxel/src/metadata.py +118 -0
- nnInteractive/supervoxel/src/reader.py +175 -0
- nnInteractive/supervoxel/src/run.py +136 -0
- nnInteractive/supervoxel/src/sam2/__init__.py +2 -0
- nnInteractive/supervoxel/src/sam2/sam2/__init__.py +11 -0
- nnInteractive/supervoxel/src/sam2/sam2/automatic_mask_generator.py +434 -0
- nnInteractive/supervoxel/src/sam2/sam2/benchmark.py +86 -0
- nnInteractive/supervoxel/src/sam2/sam2/build_sam.py +172 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/__init__.py +5 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/backbones/__init__.py +5 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/backbones/hieradet.py +305 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/backbones/image_encoder.py +132 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/backbones/utils.py +89 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/memory_attention.py +167 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/memory_encoder.py +179 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/position_encoding.py +217 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/sam/__init__.py +5 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/sam/mask_decoder.py +274 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/sam/prompt_encoder.py +194 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/sam/transformer.py +293 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/sam2_base.py +879 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/sam2_utils.py +315 -0
- nnInteractive/supervoxel/src/sam2/sam2/sam2_image_predictor.py +433 -0
- nnInteractive/supervoxel/src/sam2/sam2/sam2_video_predictor.py +1171 -0
- nnInteractive/supervoxel/src/sam2/sam2/sam2_video_predictor_legacy.py +1125 -0
- nnInteractive/supervoxel/src/sam2/sam2/utils/__init__.py +5 -0
- nnInteractive/supervoxel/src/sam2/sam2/utils/amg.py +332 -0
- nnInteractive/supervoxel/src/sam2/sam2/utils/misc.py +488 -0
- nnInteractive/supervoxel/src/sam2/sam2/utils/transforms.py +108 -0
- nnInteractive/supervoxel/src/sam2/setup.py +174 -0
- nnInteractive/supervoxel/src/sam2/training/__init__.py +5 -0
- nnInteractive/supervoxel/src/sam2/training/dataset/__init__.py +5 -0
- nnInteractive/supervoxel/src/sam2/training/dataset/sam2_datasets.py +176 -0
- nnInteractive/supervoxel/src/sam2/training/dataset/transforms.py +481 -0
- nnInteractive/supervoxel/src/sam2/training/dataset/utils.py +102 -0
- nnInteractive/supervoxel/src/sam2/training/dataset/vos_dataset.py +154 -0
- nnInteractive/supervoxel/src/sam2/training/dataset/vos_raw_dataset.py +290 -0
- nnInteractive/supervoxel/src/sam2/training/dataset/vos_sampler.py +103 -0
- nnInteractive/supervoxel/src/sam2/training/dataset/vos_segment_loader.py +289 -0
- nnInteractive/supervoxel/src/sam2/training/loss_fns.py +290 -0
- nnInteractive/supervoxel/src/sam2/training/model/__init__.py +5 -0
- nnInteractive/supervoxel/src/sam2/training/model/sam2.py +515 -0
- nnInteractive/supervoxel/src/sam2/training/optimizer.py +462 -0
- nnInteractive/supervoxel/src/sam2/training/scripts/sav_frame_extraction_submitit.py +157 -0
- nnInteractive/supervoxel/src/sam2/training/train.py +232 -0
- nnInteractive/supervoxel/src/sam2/training/trainer.py +1051 -0
- nnInteractive/supervoxel/src/sam2/training/utils/__init__.py +5 -0
- nnInteractive/supervoxel/src/sam2/training/utils/checkpoint_utils.py +328 -0
- nnInteractive/supervoxel/src/sam2/training/utils/data_utils.py +166 -0
- nnInteractive/supervoxel/src/sam2/training/utils/distributed.py +560 -0
- nnInteractive/supervoxel/src/sam2/training/utils/logger.py +236 -0
- nnInteractive/supervoxel/src/sam2/training/utils/train_utils.py +275 -0
- nnInteractive/supervoxel/src/supervoxel.py +198 -0
- nnInteractive/trainer/__init__.py +0 -0
- nnInteractive/trainer/nnInteractiveTrainer.py +24 -0
- nnInteractive/utils/__init__.py +0 -0
- nnInteractive/utils/bboxes.py +217 -0
- nnInteractive/utils/checkpoint_cleansing.py +9 -0
- nnInteractive/utils/crop.py +268 -0
- nnInteractive/utils/erosion_dilation.py +48 -0
- nnInteractive/utils/inference_helpers.py +45 -0
- nnInteractive/utils/os_shennanigans.py +16 -0
- nnInteractive/utils/rounding.py +13 -0
- nninteractive-2.0.0.dist-info/METADATA +511 -0
- nninteractive-2.0.0.dist-info/RECORD +76 -0
- nninteractive-2.0.0.dist-info/WHEEL +5 -0
- nninteractive-2.0.0.dist-info/licenses/LICENSE +201 -0
- nninteractive-2.0.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,462 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
|
|
4
|
+
# This source code is licensed under the license found in the
|
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
|
6
|
+
|
|
7
|
+
import fnmatch
|
|
8
|
+
import inspect
|
|
9
|
+
import itertools
|
|
10
|
+
import logging
|
|
11
|
+
import types
|
|
12
|
+
from typing import (
|
|
13
|
+
Any,
|
|
14
|
+
Callable,
|
|
15
|
+
Dict,
|
|
16
|
+
Iterable,
|
|
17
|
+
List,
|
|
18
|
+
Mapping,
|
|
19
|
+
Optional,
|
|
20
|
+
Set,
|
|
21
|
+
Tuple,
|
|
22
|
+
Type,
|
|
23
|
+
Union,
|
|
24
|
+
)
|
|
25
|
+
|
|
26
|
+
import hydra
|
|
27
|
+
|
|
28
|
+
import torch
|
|
29
|
+
import torch.nn as nn
|
|
30
|
+
from omegaconf import DictConfig
|
|
31
|
+
from torch import Tensor
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class Optimizer:
    """Pair an optimizer with per-param-group schedulers for its options.

    ``schedulers`` is either ``None`` or a sequence indexed like
    ``optimizer.param_groups``; each entry maps an option name (a key of
    ``optimizer.defaults``) to a callable that produces that option's value.
    """

    def __init__(self, optimizer, schedulers=None) -> None:
        self.optimizer = optimizer
        self.schedulers = schedulers
        self._validate_optimizer_schedulers()
        # Write the scheduled values for the very first step right away.
        self.step_schedulers(0.0, 0)

    def _validate_optimizer_schedulers(self):
        """Assert that every scheduled option is one of the optimizer defaults."""
        if self.schedulers is None:
            return
        for group_schedulers in self.schedulers:
            for option, _ in group_schedulers.items():
                assert option in self.optimizer.defaults, (
                    "Optimizer option "
                    f"{option} not found in {self.optimizer}. Valid options are "
                    f"{self.optimizer.defaults.keys()}"
                )

    def step_schedulers(self, where: float, step: int) -> None:
        """Evaluate every scheduler and store the result in its param group.

        Args:
            where: Forwarded to each scheduler (positionally, or as ``where=``
                for schedulers that also accept ``step``).
            step: Forwarded as ``step=`` to schedulers whose ``__call__``
                declares a ``step`` parameter.
        """
        if self.schedulers is None:
            return

        def _declares_step(candidate) -> bool:
            return "step" in inspect.signature(candidate.__call__).parameters

        for group_idx, group in enumerate(self.optimizer.param_groups):
            for option, scheduler in self.schedulers[group_idx].items():
                # The second test handles ValueScaler-style wrappers, whose own
                # __call__ is generic but whose wrapped scheduler takes `step`.
                wants_step = _declares_step(scheduler) or (
                    hasattr(scheduler, "scheduler") and _declares_step(scheduler.scheduler)
                )
                if wants_step:
                    group[option] = scheduler(step=step, where=where)
                else:
                    group[option] = scheduler(where)

    def step(self, where, step, closure=None):
        """Refresh the scheduled options, then run one optimizer step."""
        self.step_schedulers(where, step)
        return self.optimizer.step(closure)

    def zero_grad(self, *args, **kwargs):
        """Forward to the wrapped optimizer's ``zero_grad``."""
        return self.optimizer.zero_grad(*args, **kwargs)
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def set_default_parameters(scheduler_cfgs: List[DictConfig], all_parameter_names: Set[str]) -> None:
    """Assign the unclaimed parameters to the "default" scheduler config.

    A config whose ``parameter_names`` is ``None`` is treated as the default
    and receives every name in ``all_parameter_names`` that no other config
    claims. If no config is a default, a scheduler-less entry holding those
    leftover names is appended. ``scheduler_cfgs`` is modified in place.

    Args:
        scheduler_cfgs: A list of scheduler configs, where each scheduler also
            specifies which parameters it applies to, based on the names of
            parameters or the class of the modules. At most one scheduler is
            allowed to skip this specification, which is used as a "default"
            specification for any remaining parameters.
        all_parameter_names: Names of all the parameters to consider.
    """
    explicit_sets = [cfg.parameter_names for cfg in scheduler_cfgs if cfg.parameter_names is not None]
    if explicit_sets:
        leftover = all_parameter_names - set.union(*explicit_sets)
    else:
        leftover = set(all_parameter_names)

    num_defaults = 0
    for cfg in scheduler_cfgs:
        if cfg.parameter_names is not None:
            continue
        cfg.parameter_names = leftover
        num_defaults += 1
    assert num_defaults <= 1, "Only one scheduler per option can be default"

    if num_defaults == 0:
        # No default scheduler specified, add a default, but without any scheduler
        # for that option
        scheduler_cfgs.append({"parameter_names": leftover})
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
def name_constraints_to_parameters(
    param_constraints: List[Set[str]], named_parameters: Dict[str, Tensor]
) -> List[torch.nn.Parameter]:
    """Select the parameters whose names satisfy every constraint set.

    The result holds the parameter objects (not their names), in the iteration
    order of ``named_parameters``.

    Args:
        param_constraints: A list, with each element being a set of allowed
            parameter names.
        named_parameters: Mapping from a parameter name to the parameter itself.

    Returns:
        The parameters whose name appears in _each_ set of ``param_constraints``.
    """
    allowed_names = set.intersection(*param_constraints)
    selected = []
    for name, parameter in named_parameters.items():
        if name in allowed_names:
            selected.append(parameter)
    return selected
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
def map_scheduler_cfgs_to_param_groups(
    all_scheduler_cfgs: Iterable[List[Dict]],
    named_parameters: Dict[str, Tensor],
) -> Tuple[List[Dict[Any, Any]], List[Dict[str, List[torch.nn.Parameter]]]]:
    """Build one param group per compatible combination of scheduler configs.

    Every element of ``all_scheduler_cfgs`` lists the configs of one optimizer
    option (like "lr" or "weight_decay"). Each way of picking one config per
    option is considered; the parameters named by all picked configs form a
    param group, and combinations sharing no parameter are dropped.

    Args:
        all_scheduler_cfgs: All the scheduler configs covering every option.
        named_parameters: Mapping from a parameter name to the parameter itself.

    Returns:
        Tuple of lists of schedulers and param_groups, where schedulers[i]
        applies to param_groups[i].
    """
    schedulers, param_groups = [], []
    for combination in itertools.product(*all_scheduler_cfgs):
        # Names that every config of this combination applies to.
        shared_names = set.intersection(*[cfg["parameter_names"] for cfg in combination])
        group_params = [param for name, param in named_parameters.items() if name in shared_names]
        if not group_params:  # If no overlap of parameters, skip
            continue
        # Entries without an "option" key carry no scheduler for this group.
        schedulers.append({cfg["option"]: cfg["scheduler"] for cfg in combination if "option" in cfg})
        param_groups.append({"params": group_params})
    return schedulers, param_groups
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
def validate_param_group_params(param_groups: List[Dict], model: nn.Module):
    """Check that the param groups are non-overlapping and cover all the parameters.

    Args:
        param_groups: List of all param groups; each is a dict whose "params"
            entry lists that group's (hashable) parameters.
        model: Model to validate against. The check ensures that all the model
            parameters are part of param_groups

    Raises:
        AssertionError: If a parameter is repeated within a group, if two
            groups share a parameter, or if the groups together do not match
            the set of parameters yielded by ``model.named_parameters()``.
    """
    for pg in param_groups:
        # no param should be repeated within a group
        assert len(pg["params"]) == len(set(pg["params"]))
    parameters = [set(param_group["params"]) for param_group in param_groups]
    model_parameters = {parameter for _, parameter in model.named_parameters()}
    for p1, p2 in itertools.permutations(parameters, 2):
        assert p1.isdisjoint(p2), "Scheduler generated param_groups should be disjoint"
    # Starting from an empty set keeps the union well-defined when there are no
    # param groups at all, so that case reaches the descriptive assertion below
    # instead of failing with a TypeError from `set.union()`.
    covered_parameters = set().union(*parameters)
    assert covered_parameters == model_parameters, (
        "Scheduler generated param_groups must include all parameters of the model."
        f" Found {len(covered_parameters)} params whereas model has"
        f" {len(model_parameters)} params"
    )
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
def unix_module_cls_pattern_to_parameter_names(
    filter_module_cls_names: List[str],
    module_cls_to_param_names: Dict[Type, str],
) -> Union[None, Set[str]]:
    """Returns param names which pass the filters specified in filter_module_cls_names.

    Args:
        filter_module_cls_names: A list of filter strings containing class names, like
            ["torch.nn.LayerNorm", "torch.nn.BatchNorm2d"]. ``None`` or an empty
            list selects nothing.
        module_cls_to_param_names: Mapping from module classes to the parameter names
            they contain. See `get_module_cls_to_param_names`.

    Returns:
        The union of the parameter names owned by the listed module classes.

    Raises:
        AssertionError: If a listed class is not a key of
            ``module_cls_to_param_names`` or maps to no parameter names.
    """
    if filter_module_cls_names is None:
        return set()
    allowed_parameter_names = []
    for module_cls_name in filter_module_cls_names:
        # Resolve the dotted class path to the class object used as mapping key.
        module_cls = hydra.utils.get_class(module_cls_name)
        if module_cls not in module_cls_to_param_names:
            raise AssertionError(f"module_cls_name {module_cls_name} does not " "match any classes in the model")
        matching_parameters = module_cls_to_param_names[module_cls]
        assert (
            len(matching_parameters) > 0
        ), f"module_cls_name {module_cls_name} does not contain any parameters in the model"
        logging.info(f"Matches for module_cls_name [{module_cls_name}]: {matching_parameters} ")
        allowed_parameter_names.append(matching_parameters)
    # `set().union(...)` rather than `set.union(...)`: the latter raises a
    # TypeError when the filter list is empty.
    return set().union(*allowed_parameter_names)
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
def unix_param_pattern_to_parameter_names(
    filter_param_names: Optional[List[str]],
    parameter_names: Dict[str, torch.Tensor],
) -> Union[None, Set[str]]:
    """Returns param names which pass the filters specified in filter_param_names.

    Args:
        filter_param_names: A list of unix-style filter strings with optional
            wildcards, like ["block.2.*", "block.2.linear.weight"]. ``None`` or
            an empty list selects nothing.
        parameter_names: The candidate parameter names; iterating it must yield
            the names (e.g. a set of names or a mapping keyed by name).

    Returns:
        The union of the names matched by each pattern.

    Raises:
        AssertionError: If a pattern matches no parameter name.
    """

    if filter_param_names is None:
        return set()
    allowed_parameter_names = []
    for param_name in filter_param_names:
        matching_parameters = set(fnmatch.filter(parameter_names, param_name))
        assert len(matching_parameters) >= 1, f"param_name {param_name} does not match any parameters in the model"
        logging.info(f"Matches for param_name [{param_name}]: {matching_parameters}")
        allowed_parameter_names.append(matching_parameters)
    # `set().union(...)` rather than `set.union(...)`: the latter raises a
    # TypeError when the filter list is empty.
    return set().union(*allowed_parameter_names)
|
|
235
|
+
|
|
236
|
+
|
|
237
|
+
def _unix_pattern_to_parameter_names(
|
|
238
|
+
scheduler_cfg: DictConfig,
|
|
239
|
+
parameter_names: Set[str],
|
|
240
|
+
module_cls_to_param_names: Dict[Type, str],
|
|
241
|
+
) -> Union[None, Set[str]]:
|
|
242
|
+
"""Returns param names which pass the filters specified in scheduler_cfg.
|
|
243
|
+
|
|
244
|
+
Args:
|
|
245
|
+
scheduler_cfg: The config for the scheduler
|
|
246
|
+
parameter_names: The set of all parameter names which will be filtered
|
|
247
|
+
"""
|
|
248
|
+
if "param_names" not in scheduler_cfg and "module_cls_names" not in scheduler_cfg:
|
|
249
|
+
return None
|
|
250
|
+
return unix_param_pattern_to_parameter_names(scheduler_cfg.get("param_names"), parameter_names).union(
|
|
251
|
+
unix_module_cls_pattern_to_parameter_names(scheduler_cfg.get("module_cls_names"), module_cls_to_param_names)
|
|
252
|
+
)
|
|
253
|
+
|
|
254
|
+
|
|
255
|
+
def get_module_cls_to_param_names(model: nn.Module, param_allowlist: Set[str] = None) -> Dict[Type, str]:
    """Map every module class in the model to the names of the params it owns.

    A parameter is attributed only to its immediate parent module; enclosing
    (recursive) parents do not count. Every module class encountered gets an
    entry, even when it ends up owning no parameter.

    Args:
        model: Model to iterate over
        param_allowlist: If specified, only these param names will be processed
    """
    class_to_names = {}
    for module_prefix, submodule in model.named_modules():
        owned_names = class_to_names.setdefault(type(submodule), set())
        for local_name, _ in submodule.named_parameters(recurse=False):
            # The root module has an empty name, so its params get no prefix.
            qualified_name = local_name if module_prefix == "" else f"{module_prefix}.{local_name}"
            if param_allowlist is not None and qualified_name not in param_allowlist:
                continue
            owned_names.add(qualified_name)
    return class_to_names
|
|
275
|
+
|
|
276
|
+
|
|
277
|
+
def construct_optimizer(
    model: torch.nn.Module,
    optimizer_conf: Any,
    options_conf: Mapping[str, List] = None,
    param_group_modifiers_conf: List[Callable] = None,
    param_allowlist: Optional[Set[str]] = None,
    validate_param_groups=True,
) -> Optimizer:
    """
    Build the `Optimizer` wrapper for `model` from Hydra configs.

    Without `options_conf` the optimizer is instantiated directly over the
    selected parameters and no schedulers are attached. Otherwise the
    per-option scheduler configs are resolved to parameter names, optionally
    rewritten by the param group modifiers, and turned into param groups with
    one scheduler mapping per group.

    Args:
        model: model whose parameters are optimized.
        optimizer_conf: Hydra config consisting a partial torch optimizer like SGD or
            ADAM, still missing the params argument which this function provides to
            produce the final optimizer
        options_conf: Hydra config which instantiates to a mapping from an
            optimizer option name to the list of scheduler configs for that option.
        param_group_modifiers_conf: Optional user specified functions which can modify
            the final scheduler configs before the optimizer's param groups are built
        param_allowlist: The parameters to optimize. Parameters which are not part of
            this allowlist will be skipped.
        validate_param_groups: If enabled, validates that the produced param_groups don't
            overlap and cover all the model parameters.
    """
    if param_allowlist is None:
        param_allowlist = {name for name, _ in model.named_parameters()}

    named_parameters = {name: param for name, param in model.named_parameters() if name in param_allowlist}

    # No per-option schedulers requested: a plain optimizer over the selection.
    if not options_conf:
        return Optimizer(hydra.utils.instantiate(optimizer_conf, named_parameters.values()))

    all_parameter_names = set(named_parameters)
    module_cls_to_all_param_names = get_module_cls_to_param_names(model, param_allowlist)

    per_option_cfgs = hydra.utils.instantiate(options_conf)
    all_scheduler_cfgs = []
    for option, cfgs_for_option in per_option_cfgs.items():
        for cfg in cfgs_for_option:
            cfg.option = option
            cfg.parameter_names = _unix_pattern_to_parameter_names(
                cfg, all_parameter_names, module_cls_to_all_param_names
            )
        set_default_parameters(cfgs_for_option, all_parameter_names)
        all_scheduler_cfgs.append(cfgs_for_option)

    for modifier_conf in param_group_modifiers_conf or ():
        modifier = hydra.utils.instantiate(modifier_conf)
        all_scheduler_cfgs = modifier(scheduler_cfgs=all_scheduler_cfgs, model=model)

    schedulers, param_groups = map_scheduler_cfgs_to_param_groups(all_scheduler_cfgs, named_parameters)
    if validate_param_groups:
        validate_param_group_params(param_groups, model)
    return Optimizer(hydra.utils.instantiate(optimizer_conf, param_groups), schedulers)
|
|
338
|
+
|
|
339
|
+
|
|
340
|
+
def get_full_parameter_name(module_name, param_name):
    """Join a module name and a parameter name with a dot.

    An empty ``module_name`` yields ``param_name`` unchanged.
    """
    return f"{module_name}.{param_name}" if module_name != "" else param_name
|
|
344
|
+
|
|
345
|
+
|
|
346
|
+
class GradientClipper:
    """
    Callable that clips the gradient norm of a model's parameters.

    A ``max_norm`` of ``None`` turns the clipper into a no-op.
    """

    def __init__(self, max_norm: float = 1.0, norm_type: int = 2):
        assert max_norm is None or isinstance(max_norm, (int, float))
        self.max_norm = None if max_norm is None else float(max_norm)
        self.norm_type = norm_type

    def __call__(self, model: nn.Module):
        """Clip the gradients of ``model.parameters()`` unless clipping is disabled."""
        if self.max_norm is not None:
            nn.utils.clip_grad_norm_(model.parameters(), max_norm=self.max_norm, norm_type=self.norm_type)
|
|
361
|
+
|
|
362
|
+
|
|
363
|
+
class ValueScaler:
    """Wrap a scheduler so that every value it returns is multiplied by ``mult_val``."""

    def __init__(self, scheduler, mult_val: float):
        self.scheduler = scheduler
        self.mult_val = mult_val

    def __call__(self, *args, **kwargs):
        """Call the wrapped scheduler with the same arguments and scale its result."""
        return self.scheduler(*args, **kwargs) * self.mult_val
|
|
371
|
+
|
|
372
|
+
|
|
373
|
+
def rgetattr(obj, rattrs: str = None):
    """
    getattr() with support for dotted attribute paths.
    For rattrs of the form 'attr1.attr2' this returns obj.attr1.attr2;
    when rattrs is None, obj itself is returned.
    """
    if rattrs is None:
        return obj
    remaining = rattrs
    while True:
        # Peel off the leading attribute name; `dot` is empty on the last one.
        head, dot, remaining = remaining.partition(".")
        obj = getattr(obj, head)
        if not dot:
            return obj
|
|
384
|
+
|
|
385
|
+
|
|
386
|
+
def layer_decay_param_modifier(
    scheduler_cfgs: List[List[Dict]],
    model,
    layer_decay_value: float,
    layer_decay_min: Optional[float] = None,
    apply_to: Optional[str] = None,
    overrides: List[Dict] = (),
) -> List[List[Dict]]:
    """
    Split every "lr" scheduler config into per-layer configs with a scaled scheduler.

    Args
        - scheduler_cfgs: a list of omegaconf.ListConfigs.
            Each element in the list is a omegaconf.DictConfig with the following structure
            {
                "scheduler": <some fvcore scheduler>
                "option": <value> possible options are "lr", "weight_decay" etc.
                "parameter_names": Set of str indicating param names that this scheduler applies to
            }
        - model: a model that implements a method `get_layer_id` that maps layer_name to an integer
            and a method get_num_layers.
            Alternatively, use apply_to argument to select a specific component of the model.
        - layer_decay_value: float
        - layer_decay_min: min val for layer decay
        - apply_to: optional arg to select which component of the model to apply the layer decay
            modifier to. When None, the model itself is used and every parameter name is passed
            to `get_layer_id`.
        - overrides: to manually override lr for specific patterns. Is a list of dicts. Each dict
            has keys "pattern", "value". The first matching pattern wins.
    Returns
        - scheduler_configs: same structure as the input, elements can be modified
    """
    model = rgetattr(model, apply_to)
    # `str.startswith(None)` raises TypeError, so the default `apply_to=None`
    # used to crash on the first "lr" parameter. The empty prefix matches every
    # name, which is consistent with `rgetattr` returning the whole model.
    name_prefix = "" if apply_to is None else apply_to
    num_layers = model.get_num_layers() + 1
    layer_decays = [layer_decay_value ** (num_layers - i) for i in range(num_layers + 1)]
    if layer_decay_min is not None:
        layer_decays = [max(val, layer_decay_min) for val in layer_decays]
    final_scheduler_cfgs = []
    # scheduler_cfgs is a list of lists
    for scheduler_cfg_group in scheduler_cfgs:
        curr_cfg_group = []
        # scheduler_cfg_group is a list of dictionaries
        for scheduler_cfg in scheduler_cfg_group:
            # `.get` because the scheduler-less default entry appended by
            # `set_default_parameters` has no "option" key; it passes through.
            if scheduler_cfg.get("option") != "lr":
                curr_cfg_group.append(scheduler_cfg)
                continue
            # Need sorted so that the list of parameter names is deterministic and consistent
            # across re-runs of this job. Else it was causing issues with loading the optimizer
            # state during a job restart (D38591759)
            parameter_names = sorted(scheduler_cfg["parameter_names"])

            # Only want one cfg group per layer
            layer_cfg_groups = {}
            for param_name in parameter_names:
                # Parameters outside the selected component use the last
                # entry of layer_decays (exponent 0).
                layer_id = num_layers
                this_scale = layer_decays[layer_id]
                if param_name.startswith(name_prefix):
                    layer_id = model.get_layer_id(param_name)
                    this_scale = layer_decays[layer_id]
                    # Overrides
                    for override in overrides:
                        if fnmatch.fnmatchcase(param_name, override["pattern"]):
                            this_scale = float(override["value"])
                            # Group overridden params by pattern, not by layer.
                            layer_id = override["pattern"]
                            break

                if layer_id not in layer_cfg_groups:
                    layer_cfg_groups[layer_id] = {
                        "option": scheduler_cfg["option"],
                        "scheduler": ValueScaler(scheduler_cfg["scheduler"], this_scale),
                        "parameter_names": {param_name},
                    }
                else:
                    layer_cfg_groups[layer_id]["parameter_names"].add(param_name)

            curr_cfg_group.extend(layer_cfg_groups.values())

        final_scheduler_cfgs.append(curr_cfg_group)
    return final_scheduler_cfgs
|
|
@@ -0,0 +1,157 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
import argparse
|
|
4
|
+
import os
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
|
|
7
|
+
import cv2
|
|
8
|
+
|
|
9
|
+
import numpy as np
|
|
10
|
+
import submitit
|
|
11
|
+
import tqdm
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def get_args_parser():
    """Build the command-line parser for the SA-V JPEG frame extraction script.

    Returns:
        An ``argparse.ArgumentParser`` with three argument groups: input data,
        cluster launch settings and output locations.
    """
    parser = argparse.ArgumentParser(
        description="[SA-V Preprocessing] Extracting JPEG frames",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # ------------
    # DATA
    # ------------
    data_parser = parser.add_argument_group(
        title="SA-V dataset data root",
        description="What data to load and how to process it.",
    )
    data_parser.add_argument(
        "--sav-vid-dir",
        type=str,
        required=True,
        help=("Where to find the SAV videos"),
    )
    data_parser.add_argument(
        "--sav-frame-sample-rate",
        type=int,
        default=4,
        help="Rate at which to sub-sample frames",
    )

    # ------------
    # LAUNCH
    # ------------
    launch_parser = parser.add_argument_group(
        title="Cluster launch settings",
        description="Number of jobs and retry settings.",
    )
    launch_parser.add_argument(
        "--n-jobs",
        type=int,
        required=True,
        help="Shard the run over this many jobs.",
    )
    launch_parser.add_argument("--timeout", type=int, required=True, help="SLURM timeout parameter in minutes.")
    launch_parser.add_argument("--partition", type=str, required=True, help="Partition to launch on.")
    # The help text used to be a copy of the --partition help.
    launch_parser.add_argument("--account", type=str, required=True, help="SLURM account to launch under.")
    launch_parser.add_argument("--qos", type=str, required=True, help="QOS.")

    # ------------
    # OUTPUT
    # ------------
    output_parser = parser.add_argument_group(
        title="Setting for results output", description="Where and how to save results."
    )
    output_parser.add_argument(
        "--output-dir",
        type=str,
        required=True,
        help=("Where to dump the extracted jpeg frames"),
    )
    output_parser.add_argument(
        "--slurm-output-root-dir",
        type=str,
        required=True,
        help=("Where to save slurm outputs"),
    )
    return parser
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def decode_video(video_path: str):
    """Decode every frame of a video file into memory.

    Args:
        video_path: Path of the video file; it must exist.

    Returns:
        A list with the frames returned by ``cv2.VideoCapture.read`` in
        reading order. Reading stops at the first unsuccessful read.

    Raises:
        AssertionError: If ``video_path`` does not exist.
    """
    assert os.path.exists(video_path)
    video = cv2.VideoCapture(video_path)
    video_frames = []
    try:
        while video.isOpened():
            ret, frame = video.read()
            if not ret:
                break
            video_frames.append(frame)
    finally:
        # Release the capture handle explicitly; the original never did, so
        # handles stayed open while a job looped over many videos.
        video.release()
    return video_frames
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def extract_frames(video_path, sample_rate):
    """Decode a video and keep every ``sample_rate``-th frame, starting with the first."""
    return decode_video(video_path)[::sample_rate]
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def submitit_launch(video_paths, sample_rate, save_root):
    """Extract sub-sampled JPEG frames for each video of one job's shard.

    Frames of a video are written to ``<save_root>/<video file stem>/`` and
    named after their index in the original video, zero-padded to 5 digits.

    Args:
        video_paths: Paths of the videos to process.
        sample_rate: Keep every ``sample_rate``-th frame.
        save_root: Directory under which the per-video folders are created.
    """
    for path in tqdm.tqdm(video_paths):
        frames = extract_frames(path, sample_rate)
        output_folder = os.path.join(save_root, Path(path).stem)
        # exist_ok avoids the check-then-create race of `exists()` followed by
        # `makedirs()`: the glob in the launcher spans several sub-directories,
        # so two jobs can target the same stem folder at the same time.
        os.makedirs(output_folder, exist_ok=True)
        for fid, frame in enumerate(frames):
            # fid indexes the sub-sampled list; multiply back to the source index.
            frame_path = os.path.join(output_folder, f"{fid*sample_rate:05d}.jpg")
            cv2.imwrite(frame_path, frame)
    print(f"Saved output to {save_root}")
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
if __name__ == "__main__":
    # Parse the command line.
    parser = get_args_parser()
    args = parser.parse_args()

    sav_vid_dir, save_root = args.sav_vid_dir, args.output_dir
    sample_rate = args.sav_frame_sample_rate

    # List all SA-V videos and shard them into one chunk per job.
    mp4_files = np.array(sorted([str(p) for p in Path(sav_vid_dir).glob("*/*.mp4")]))
    chunked_mp4_files = [chunk.tolist() for chunk in np.array_split(mp4_files, args.n_jobs)]

    print(f"Processing videos in: {sav_vid_dir}")
    print(f"Processing {len(mp4_files)} files")
    print(f"Beginning processing in {args.n_jobs} processes")

    # Submitit params
    cpus_per_task = 4
    jobs_dir = os.path.join(args.slurm_output_root_dir, "%j")
    executor = submitit.AutoExecutor(folder=jobs_dir)
    executor.update_parameters(
        timeout_min=args.timeout,
        gpus_per_node=0,
        tasks_per_node=1,
        slurm_array_parallelism=args.n_jobs,
        cpus_per_task=cpus_per_task,
        slurm_partition=args.partition,
        slurm_account=args.account,
        slurm_qos=args.qos,
    )
    executor.update_parameters(slurm_srun_args=["-vv", "--cpu-bind", "none"])

    # Launch: one submitted job per chunk, all submitted inside one batch.
    jobs = []
    with executor.batch():
        for _, mp4_chunk in tqdm.tqdm(enumerate(chunked_mp4_files)):
            jobs.append(
                executor.submit(
                    submitit_launch,
                    video_paths=mp4_chunk,
                    sample_rate=sample_rate,
                    save_root=save_root,
                )
            )

    for j in jobs:
        print(f"Slurm JobID: {j.job_id}")
    print(f"Saving outputs to {save_root}")
    print(f"Slurm outputs at {args.slurm_output_root_dir}")
|