nnInteractive 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nnInteractive/__init__.py +3 -0
- nnInteractive/inference/__init__.py +0 -0
- nnInteractive/inference/cvpr2025_challenge_baseline/__init__.py +0 -0
- nnInteractive/inference/cvpr2025_challenge_baseline/predict.py +173 -0
- nnInteractive/inference/inference_session.py +1400 -0
- nnInteractive/interaction/__init__.py +0 -0
- nnInteractive/interaction/point.py +166 -0
- nnInteractive/supervoxel/setup.py +4 -0
- nnInteractive/supervoxel/src/metadata.py +118 -0
- nnInteractive/supervoxel/src/reader.py +175 -0
- nnInteractive/supervoxel/src/run.py +136 -0
- nnInteractive/supervoxel/src/sam2/__init__.py +2 -0
- nnInteractive/supervoxel/src/sam2/sam2/__init__.py +11 -0
- nnInteractive/supervoxel/src/sam2/sam2/automatic_mask_generator.py +434 -0
- nnInteractive/supervoxel/src/sam2/sam2/benchmark.py +86 -0
- nnInteractive/supervoxel/src/sam2/sam2/build_sam.py +172 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/__init__.py +5 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/backbones/__init__.py +5 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/backbones/hieradet.py +305 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/backbones/image_encoder.py +132 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/backbones/utils.py +89 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/memory_attention.py +167 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/memory_encoder.py +179 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/position_encoding.py +217 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/sam/__init__.py +5 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/sam/mask_decoder.py +274 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/sam/prompt_encoder.py +194 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/sam/transformer.py +293 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/sam2_base.py +879 -0
- nnInteractive/supervoxel/src/sam2/sam2/modeling/sam2_utils.py +315 -0
- nnInteractive/supervoxel/src/sam2/sam2/sam2_image_predictor.py +433 -0
- nnInteractive/supervoxel/src/sam2/sam2/sam2_video_predictor.py +1171 -0
- nnInteractive/supervoxel/src/sam2/sam2/sam2_video_predictor_legacy.py +1125 -0
- nnInteractive/supervoxel/src/sam2/sam2/utils/__init__.py +5 -0
- nnInteractive/supervoxel/src/sam2/sam2/utils/amg.py +332 -0
- nnInteractive/supervoxel/src/sam2/sam2/utils/misc.py +488 -0
- nnInteractive/supervoxel/src/sam2/sam2/utils/transforms.py +108 -0
- nnInteractive/supervoxel/src/sam2/setup.py +174 -0
- nnInteractive/supervoxel/src/sam2/training/__init__.py +5 -0
- nnInteractive/supervoxel/src/sam2/training/dataset/__init__.py +5 -0
- nnInteractive/supervoxel/src/sam2/training/dataset/sam2_datasets.py +176 -0
- nnInteractive/supervoxel/src/sam2/training/dataset/transforms.py +481 -0
- nnInteractive/supervoxel/src/sam2/training/dataset/utils.py +102 -0
- nnInteractive/supervoxel/src/sam2/training/dataset/vos_dataset.py +154 -0
- nnInteractive/supervoxel/src/sam2/training/dataset/vos_raw_dataset.py +290 -0
- nnInteractive/supervoxel/src/sam2/training/dataset/vos_sampler.py +103 -0
- nnInteractive/supervoxel/src/sam2/training/dataset/vos_segment_loader.py +289 -0
- nnInteractive/supervoxel/src/sam2/training/loss_fns.py +290 -0
- nnInteractive/supervoxel/src/sam2/training/model/__init__.py +5 -0
- nnInteractive/supervoxel/src/sam2/training/model/sam2.py +515 -0
- nnInteractive/supervoxel/src/sam2/training/optimizer.py +462 -0
- nnInteractive/supervoxel/src/sam2/training/scripts/sav_frame_extraction_submitit.py +157 -0
- nnInteractive/supervoxel/src/sam2/training/train.py +232 -0
- nnInteractive/supervoxel/src/sam2/training/trainer.py +1051 -0
- nnInteractive/supervoxel/src/sam2/training/utils/__init__.py +5 -0
- nnInteractive/supervoxel/src/sam2/training/utils/checkpoint_utils.py +328 -0
- nnInteractive/supervoxel/src/sam2/training/utils/data_utils.py +166 -0
- nnInteractive/supervoxel/src/sam2/training/utils/distributed.py +560 -0
- nnInteractive/supervoxel/src/sam2/training/utils/logger.py +236 -0
- nnInteractive/supervoxel/src/sam2/training/utils/train_utils.py +275 -0
- nnInteractive/supervoxel/src/supervoxel.py +198 -0
- nnInteractive/trainer/__init__.py +0 -0
- nnInteractive/trainer/nnInteractiveTrainer.py +24 -0
- nnInteractive/utils/__init__.py +0 -0
- nnInteractive/utils/bboxes.py +217 -0
- nnInteractive/utils/checkpoint_cleansing.py +9 -0
- nnInteractive/utils/crop.py +268 -0
- nnInteractive/utils/erosion_dilation.py +48 -0
- nnInteractive/utils/inference_helpers.py +45 -0
- nnInteractive/utils/os_shennanigans.py +16 -0
- nnInteractive/utils/rounding.py +13 -0
- nninteractive-2.0.0.dist-info/METADATA +511 -0
- nninteractive-2.0.0.dist-info/RECORD +76 -0
- nninteractive-2.0.0.dist-info/WHEEL +5 -0
- nninteractive-2.0.0.dist-info/licenses/LICENSE +201 -0
- nninteractive-2.0.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,1051 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
|
|
4
|
+
# This source code is licensed under the license found in the
|
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
|
6
|
+
|
|
7
|
+
import gc
|
|
8
|
+
import json
|
|
9
|
+
import logging
|
|
10
|
+
import math
|
|
11
|
+
import os
|
|
12
|
+
import time
|
|
13
|
+
from collections import OrderedDict
|
|
14
|
+
from dataclasses import dataclass, field
|
|
15
|
+
from typing import Any, Dict, List, Mapping, Optional
|
|
16
|
+
|
|
17
|
+
import numpy as np
|
|
18
|
+
|
|
19
|
+
import torch
|
|
20
|
+
import torch.distributed as dist
|
|
21
|
+
import torch.nn as nn
|
|
22
|
+
from hydra.utils import instantiate
|
|
23
|
+
from iopath.common.file_io import g_pathmgr
|
|
24
|
+
|
|
25
|
+
from training.optimizer import construct_optimizer
|
|
26
|
+
|
|
27
|
+
from training.utils.checkpoint_utils import (
|
|
28
|
+
assert_skipped_parameters_are_frozen,
|
|
29
|
+
exclude_params_matching_unix_pattern,
|
|
30
|
+
load_state_dict_into_model,
|
|
31
|
+
with_check_parameter_frozen,
|
|
32
|
+
)
|
|
33
|
+
from training.utils.data_utils import BatchedVideoDatapoint
|
|
34
|
+
from training.utils.distributed import all_reduce_max, barrier, get_rank
|
|
35
|
+
|
|
36
|
+
from training.utils.logger import Logger, setup_logging
|
|
37
|
+
|
|
38
|
+
from training.utils.train_utils import (
|
|
39
|
+
AverageMeter,
|
|
40
|
+
collect_dict_keys,
|
|
41
|
+
DurationMeter,
|
|
42
|
+
get_amp_type,
|
|
43
|
+
get_machine_local_and_dist_rank,
|
|
44
|
+
get_resume_checkpoint,
|
|
45
|
+
human_readable_time,
|
|
46
|
+
is_dist_avail_and_initialized,
|
|
47
|
+
log_env_variables,
|
|
48
|
+
makedir,
|
|
49
|
+
MemMeter,
|
|
50
|
+
Phase,
|
|
51
|
+
ProgressMeter,
|
|
52
|
+
set_seeds,
|
|
53
|
+
setup_distributed_backend,
|
|
54
|
+
)
|
|
55
|
+
|
|
56
|
+
# Dictionary key for the "core" entry of a loss dict. It is not referenced in
# the visible part of this file; presumably it names the loss term that is
# back-propagated when a loss function returns several components — TODO confirm.
CORE_LOSS_KEY = "core_loss"
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def unwrap_ddp_if_wrapped(model):
    """Return the module wrapped by a DistributedDataParallel instance.

    Anything that is not a ``DistributedDataParallel`` is handed back
    unchanged, so the helper is safe to call on already-unwrapped models.
    """
    is_ddp = isinstance(model, torch.nn.parallel.DistributedDataParallel)
    return model.module if is_ddp else model
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
@dataclass
class OptimAMPConf:
    """Automatic mixed precision (AMP) settings of the optimizer config."""

    # Whether autocast is switched on.
    enabled: bool = False
    # Name of the autocast dtype; the trainer resolves it with `get_amp_type`.
    amp_dtype: str = "float16"
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
@dataclass
class OptimConf:
    """Optimizer-related configuration of the trainer.

    ``amp`` may be passed as ``None``, a mapping of ``OptimAMPConf`` fields,
    or an ``OptimAMPConf``; after construction it is always an
    ``OptimAMPConf``.
    """

    optimizer: torch.optim.Optimizer = None
    options: Optional[Dict[str, Any]] = None
    param_group_modifiers: Optional[List] = None
    amp: Optional[Dict[str, Any]] = None
    gradient_clip: Any = None
    gradient_logger: Any = None

    def __post_init__(self):
        # Already normalized: nothing to do.
        if isinstance(self.amp, OptimAMPConf):
            return
        if self.amp is None:
            self.amp = {}
        amp_options = self.amp
        assert isinstance(amp_options, Mapping)
        self.amp = OptimAMPConf(**amp_options)
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
@dataclass
class DistributedConf:
    """Settings for torch.distributed and the DDP wrapper."""

    # Process-group backend; when None the trainer picks "nccl" for the
    # "cuda" accelerator and "gloo" otherwise.
    backend: Optional[str] = None  # inferred from accelerator type
    # If set, DDP gradients are communicated in bf16 (for a bfloat16 dtype)
    # or fp16 (for any other dtype) via a communication hook.
    comms_dtype: Optional[str] = None
    # Forwarded to DistributedDataParallel(find_unused_parameters=...).
    find_unused_parameters: bool = False
    # Timeout in minutes handed to the distributed backend setup.
    timeout_mins: int = 30
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
@dataclass
class CudaConf:
    """cuDNN / TF32 switches, applied only when CUDA is available."""

    cudnn_deterministic: bool = False
    cudnn_benchmark: bool = True
    # Shared TF32 switch for matmul and cuDNN unless overridden below.
    allow_tf32: bool = False
    # if not None, `matmul_allow_tf32` key will override `allow_tf32` for matmul
    matmul_allow_tf32: Optional[bool] = None
    # if not None, `cudnn_allow_tf32` key will override `allow_tf32` for cudnn
    cudnn_allow_tf32: Optional[bool] = None
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
@dataclass
class CheckpointConf:
    """Where and when checkpoints are written, and how training resumes."""

    save_dir: str
    # Also write `checkpoint_<epoch>.pt` every `save_freq` epochs (if > 0).
    save_freq: int
    # Extra epochs for which `checkpoint_<epoch>.pt` is written.
    save_list: List[int] = field(default_factory=list)
    # Hydra config of the callable applied to the model to initialize weights.
    model_weight_initializer: Any = None
    save_best_meters: Optional[List[str]] = None
    # Unix-style patterns of parameters excluded from saved checkpoints.
    skip_saving_parameters: List[str] = field(default_factory=list)
    initialize_after_preemption: Optional[bool] = None
    # if not None, training will be resumed from this checkpoint
    resume_from: Optional[str] = None

    def infer_missing(self):
        """Fill in fields whose default depends on other fields; return ``self``.

        ``initialize_after_preemption`` defaults to True exactly when some
        parameters are excluded from saving: those parameters are not in the
        checkpoint and must be re-initialized when resuming.
        """
        if self.initialize_after_preemption is not None:
            return self
        self.initialize_after_preemption = len(self.skip_saving_parameters) > 0
        return self
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
@dataclass
class LoggingConf:
    """Logging destinations, verbosity levels and frequencies."""

    log_dir: str
    # Progress is displayed every `log_freq` iterations.
    log_freq: int  # In iterations
    tensorboard_writer: Any
    # Level names passed to setup_logging; presumably for rank 0 vs. other
    # ranks — TODO confirm against setup_logging.
    log_level_primary: str = "INFO"
    log_level_secondary: str = "ERROR"
    # Scalars are logged when the step / iteration counter is a multiple of this.
    log_scalar_frequency: int = 100
    log_visual_frequency: int = 100
    scalar_keys_to_log: Optional[Dict[str, Any]] = None
    log_batch_stats: bool = False
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
class Trainer:
    """
    Trainer supporting the DDP training strategies.

    ``__init__`` sets up environment variables, device, torch.distributed,
    logging, model/optimizer/data components and checkpoint resumption;
    ``run()`` then dispatches on ``mode`` ("train", "train_only" or "val").
    """

    # Small numeric constant; not referenced in the visible part of this
    # class — TODO confirm where it is used.
    EPSILON = 1e-8
|
|
146
|
+
|
|
147
|
+
def __init__(
    self,
    *,  # the order of these args can change at any time, so they are keyword-only
    data: Dict[str, Any],
    model: Dict[str, Any],
    logging: Dict[str, Any],
    checkpoint: Dict[str, Any],
    max_epochs: int,
    mode: str = "train",
    accelerator: str = "cuda",
    seed_value: int = 123,
    val_epoch_freq: int = 1,
    distributed: Optional[Dict[str, Any]] = None,
    cuda: Optional[Dict[str, Any]] = None,
    env_variables: Optional[Dict[str, Any]] = None,
    optim: Optional[Dict[str, Any]] = None,
    optim_overrides: Optional[List[Dict[str, Any]]] = None,
    meters: Optional[Dict[str, Any]] = None,
    loss: Optional[Dict[str, Any]] = None,
):
    """Set up the environment, distributed backend and all training components.

    The order of the calls below matters:
    1. Environment variables and timers.
    2. Configs.
    3. Device and torch.distributed.
    4. Logging and seeds.
    5. Components, optimizer and datasets.
    6. Optionally copy ``checkpoint.resume_from`` into ``save_dir``.
    7. Load a checkpoint, or initialize the model fresh.
    8. Finally, wrap the model in DDP.

    Args:
        data: dataset configs, instantiated by ``_setup_dataloaders``.
        model: model config, stored as ``self.model_conf``.
        logging: keyword arguments for ``LoggingConf``. This parameter
            shadows the ``logging`` module, but only inside ``__init__``.
        checkpoint: keyword arguments for ``CheckpointConf``.
        max_epochs: number of training epochs.
        mode: "train", "train_only" or "val" (validated in ``run()``).
        accelerator: "cuda" or "cpu".
        seed_value: passed to ``set_seeds`` together with ``max_epochs`` and
            the distributed rank.
        val_epoch_freq: run an intermediate validation every N epochs.
        distributed: keyword arguments for ``DistributedConf``.
        cuda: keyword arguments for ``CudaConf``.
        env_variables: environment variables exported before anything else.
        optim: keyword arguments for ``OptimConf``; when ``None``,
            ``self.optim_conf`` is ``None``.
        optim_overrides: accepted but not used by this method (not stored on
            ``self``).
        meters: meter config, stored as ``self.meters_conf``.
        loss: loss config, stored as ``self.loss_conf``.

    Raises:
        AssertionError: if torch.distributed is not initialized after backend
            setup, or if ``checkpoint.resume_from`` does not exist.
        ValueError: from ``_setup_device`` for an unsupported accelerator.
    """

    # Environment variables are exported first so everything below sees them.
    self._setup_env_variables(env_variables)
    self._setup_timers()

    self.data_conf = data
    self.model_conf = model
    self.logging_conf = LoggingConf(**logging)
    self.checkpoint_conf = CheckpointConf(**checkpoint).infer_missing()
    self.max_epochs = max_epochs
    self.mode = mode
    self.val_epoch_freq = val_epoch_freq
    # None when no optim config is given; val_epoch guards against this.
    self.optim_conf = OptimConf(**optim) if optim is not None else None
    self.meters_conf = meters
    self.loss_conf = loss
    distributed = DistributedConf(**distributed or {})
    cuda = CudaConf(**cuda or {})
    # Fraction of training completed; updated by the training loop.
    self.where = 0.0

    self._infer_distributed_backend_if_none(distributed, accelerator)

    # Sets self.local_rank, self.distributed_rank and self.device.
    self._setup_device(accelerator)

    # Applies cuDNN/TF32 flags and sets self.rank from the backend setup.
    self._setup_torch_dist_and_backend(cuda, distributed)

    makedir(self.logging_conf.log_dir)
    setup_logging(
        __name__,
        output_dir=self.logging_conf.log_dir,
        rank=self.rank,
        log_level_primary=self.logging_conf.log_level_primary,
        log_level_secondary=self.logging_conf.log_level_secondary,
    )

    set_seeds(seed_value, self.max_epochs, self.distributed_rank)
    log_env_variables()

    assert is_dist_avail_and_initialized(), "Torch distributed needs to be initialized before calling the trainer."

    self._setup_components()  # Except Optimizer everything is setup here.
    self._move_to_device()
    self._construct_optimizers()
    self._setup_dataloaders()

    self.time_elapsed_meter = DurationMeter("Time Elapsed", self.device, ":.2f")

    if self.checkpoint_conf.resume_from is not None:
        # NOTE(review): existence is checked with os.path while the copy uses
        # g_pathmgr; a non-local resume_from path would fail this assert.
        assert os.path.exists(
            self.checkpoint_conf.resume_from
        ), f"The 'resume_from' checkpoint {self.checkpoint_conf.resume_from} does not exist!"
        dst = os.path.join(self.checkpoint_conf.save_dir, "checkpoint.pt")
        if self.distributed_rank == 0 and not os.path.exists(dst):
            # Copy the "resume_from" checkpoint to the checkpoint folder
            # if there is not a checkpoint to resume from already there
            makedir(self.checkpoint_conf.save_dir)
            g_pathmgr.copy(self.checkpoint_conf.resume_from, dst)
        # Other ranks wait here until rank 0 has finished copying.
        barrier()

    # The checkpoint is loaded into the bare model; DDP wrapping comes after.
    self.load_checkpoint()
    self._setup_ddp_distributed_training(distributed, accelerator)
    barrier()
|
|
227
|
+
|
|
228
|
+
def _setup_timers(self):
    """Reset the wall-clock bookkeeping used for elapsed-time and ETA logging.

    ``ckpt_time_elapsed`` holds the time carried over from a resumed
    checkpoint (0 for a fresh run). ``est_epoch_time`` holds per-phase
    estimates of epoch duration.
    """
    self.start_time = time.time()
    self.ckpt_time_elapsed = 0
    self.est_epoch_time = {phase_name: 0 for phase_name in (Phase.TRAIN, Phase.VAL)}
|
|
235
|
+
|
|
236
|
+
def _get_meters(self, phase_filters=None):
    """Flatten ``self.meters`` into a ``{"<phase>_<key>/<name>": meter}`` dict.

    Args:
        phase_filters: optional collection of phases to keep; ``None`` keeps
            every phase.

    Returns:
        A new flat dict. It is empty when ``self.meters`` is None, and keys
        whose meter dict is None are skipped.
    """
    if self.meters is None:
        return {}
    return {
        f"{phase}_{key}/{name}": meter
        for phase, phase_meters in self.meters.items()
        if phase_filters is None or phase in phase_filters
        for key, key_meters in phase_meters.items()
        if key_meters is not None
        for name, meter in key_meters.items()
    }
|
|
249
|
+
|
|
250
|
+
def _infer_distributed_backend_if_none(self, distributed_conf, accelerator):
    """Default ``distributed_conf.backend`` from the accelerator when it is unset.

    NCCL is used for ``"cuda"``; every other accelerator gets Gloo. A backend
    that is already configured is left untouched.
    """
    if distributed_conf.backend is not None:
        return
    if accelerator == "cuda":
        distributed_conf.backend = "nccl"
    else:
        distributed_conf.backend = "gloo"
|
|
253
|
+
|
|
254
|
+
def _setup_env_variables(self, env_variables_conf) -> None:
    """Export the configured environment variables into ``os.environ``.

    Args:
        env_variables_conf: mapping of variable name -> value, or ``None`` to
            do nothing. Non-string values (e.g. integers or booleans coming
            from a YAML/Hydra config) are converted with ``str()``.
    """
    if env_variables_conf is None:
        return
    for variable_name, value in env_variables_conf.items():
        # os.environ only accepts str values; assigning e.g. 1 raises TypeError.
        os.environ[variable_name] = value if isinstance(value, str) else str(value)
|
|
258
|
+
|
|
259
|
+
def _setup_torch_dist_and_backend(self, cuda_conf, distributed_conf) -> None:
    """Apply cuDNN/TF32 switches (if CUDA is present) and start torch.distributed.

    The per-backend flags ``matmul_allow_tf32`` / ``cudnn_allow_tf32`` win
    over the shared ``allow_tf32`` whenever they are not None. ``self.rank``
    is set to the value returned by ``setup_distributed_backend``.
    """
    if torch.cuda.is_available():
        shared_tf32 = cuda_conf.allow_tf32
        matmul_tf32 = cuda_conf.matmul_allow_tf32
        cudnn_tf32 = cuda_conf.cudnn_allow_tf32

        torch.backends.cudnn.deterministic = cuda_conf.cudnn_deterministic
        torch.backends.cudnn.benchmark = cuda_conf.cudnn_benchmark
        torch.backends.cuda.matmul.allow_tf32 = shared_tf32 if matmul_tf32 is None else matmul_tf32
        torch.backends.cudnn.allow_tf32 = shared_tf32 if cudnn_tf32 is None else cudnn_tf32

    self.rank = setup_distributed_backend(distributed_conf.backend, distributed_conf.timeout_mins)
|
|
271
|
+
|
|
272
|
+
def _setup_device(self, accelerator):
    """Resolve local/global ranks and pick the torch device for this process.

    Raises:
        ValueError: for an accelerator other than ``"cuda"`` or ``"cpu"``.
    """
    self.local_rank, self.distributed_rank = get_machine_local_and_dist_rank()
    if accelerator == "cpu":
        self.device = torch.device("cpu")
    elif accelerator == "cuda":
        self.device = torch.device("cuda", self.local_rank)
        torch.cuda.set_device(self.local_rank)
    else:
        raise ValueError(f"Unsupported accelerator: {accelerator}")
|
|
281
|
+
|
|
282
|
+
def _setup_ddp_distributed_training(self, distributed_conf, accelerator):
    """Wrap ``self.model`` in DistributedDataParallel; optionally compress gradients.

    When ``distributed_conf.comms_dtype`` is set, a DDP communication hook is
    registered. The hook compresses gradients to bfloat16 if the dtype
    resolves to ``torch.bfloat16``, and to float16 otherwise.
    """
    assert isinstance(self.model, torch.nn.Module)

    ddp_device_ids = [self.local_rank] if accelerator == "cuda" else []
    self.model = nn.parallel.DistributedDataParallel(
        self.model,
        device_ids=ddp_device_ids,
        find_unused_parameters=distributed_conf.find_unused_parameters,
    )
    if distributed_conf.comms_dtype is None:
        return

    from torch.distributed.algorithms import ddp_comm_hooks  # noqa

    default_hooks = ddp_comm_hooks.default_hooks
    if get_amp_type(distributed_conf.comms_dtype) == torch.bfloat16:
        compress_hook = default_hooks.bf16_compress_hook
        logging.info("Enabling bfloat16 grad communication")
    else:
        compress_hook = default_hooks.fp16_compress_hook
        logging.info("Enabling fp16 grad communication")
    # A None state makes the default hooks use the default (world) process group.
    self.model.register_comm_hook(None, compress_hook)
|
|
303
|
+
|
|
304
|
+
def _move_to_device(self):
    """Move the model onto ``self.device``, logging before and after."""
    target = f"{self.device} and local rank {self.local_rank}"
    logging.info(f"Moving components to device {target}.")
    self.model.to(self.device)
    logging.info(f"Done moving components to device {target}.")
|
|
310
|
+
|
|
311
|
+
def save_checkpoint(self, epoch, checkpoint_names=None):
    """Save model, optimizer, loss and trainer state for ``epoch``.

    Files written:
    - ``<save_dir>/<name>.pt`` for every name in ``checkpoint_names``
      (default ``["checkpoint"]``).
    - ``checkpoint_<epoch>.pt``, when ``epoch`` is a multiple of
      ``save_freq`` (with ``save_freq > 0``) or is listed in ``save_list``.

    Parameters matching ``skip_saving_parameters`` are excluded from the model
    state. Only rank 0 writes files.

    Args:
        epoch: epoch value stored in the checkpoint; ``int(epoch)`` is used
            for the frequency check and the per-epoch file name.
        checkpoint_names: optional list of base file names. The caller's list
            is never modified.
    """
    checkpoint_folder = self.checkpoint_conf.save_dir
    makedir(checkpoint_folder)
    # Work on a copy: appending the per-epoch name must not grow the caller's
    # list across calls.
    checkpoint_names = ["checkpoint"] if checkpoint_names is None else list(checkpoint_names)
    epoch_int = int(epoch)
    save_freq = self.checkpoint_conf.save_freq
    if (save_freq > 0 and epoch_int % save_freq == 0) or epoch_int in self.checkpoint_conf.save_list:
        checkpoint_names.append(f"checkpoint_{epoch_int}")

    checkpoint_paths = [os.path.join(checkpoint_folder, f"{ckpt_name}.pt") for ckpt_name in checkpoint_names]

    state_dict = unwrap_ddp_if_wrapped(self.model).state_dict()
    state_dict = exclude_params_matching_unix_pattern(
        patterns=self.checkpoint_conf.skip_saving_parameters, state_dict=state_dict
    )

    checkpoint = {
        "model": state_dict,
        "optimizer": self.optim.optimizer.state_dict(),
        "epoch": epoch,
        "loss": self.loss.state_dict(),
        "steps": self.steps,
        "time_elapsed": self.time_elapsed_meter.val,
        "best_meter_values": self.best_meter_values,
    }
    if self.optim_conf.amp.enabled:
        checkpoint["scaler"] = self.scaler.state_dict()

    # DDP checkpoints are only saved on rank 0 (all workers are identical)
    if self.distributed_rank != 0:
        return

    for checkpoint_path in checkpoint_paths:
        self._save_checkpoint(checkpoint, checkpoint_path)
|
|
348
|
+
|
|
349
|
+
def _save_checkpoint(self, checkpoint, checkpoint_path):
    """
    Save a checkpoint while guarding against the job being killed mid-write.

    An interrupted ``torch.save`` would leave a corrupted file. Since usually
    only the latest checkpoint is kept per run, that would ruin the entire
    training. The checkpoint is therefore first written to
    ``<checkpoint_path>.tmp`` and then moved over ``checkpoint_path``.
    """
    tmp_path = f"{checkpoint_path}.tmp"
    with g_pathmgr.open(tmp_path, "wb") as tmp_file:
        torch.save(checkpoint, tmp_file)

    # torch.save has completed: swap the new file in for the old one.
    # g_pathmgr.mv fails if the destination exists, so remove it first.
    if g_pathmgr.exists(checkpoint_path):
        g_pathmgr.rm(checkpoint_path)
    moved = g_pathmgr.mv(tmp_path, checkpoint_path)
    assert moved
|
|
367
|
+
|
|
368
|
+
def load_checkpoint(self):
    """Resume from the checkpoint found in ``save_dir``, or initialize fresh.

    If ``get_resume_checkpoint`` finds nothing, the model goes through
    ``_init_model_state``. Otherwise the weight initializer is first re-run
    when ``initialize_after_preemption`` is set (e.g. for parameters excluded
    via ``skip_saving_parameters``). The checkpoint is then loaded.
    """
    ckpt_path = get_resume_checkpoint(self.checkpoint_conf.save_dir)
    if ckpt_path is not None:
        if self.checkpoint_conf.initialize_after_preemption:
            self._call_model_initializer()
        self._load_resuming_checkpoint(ckpt_path)
        return
    self._init_model_state()
|
|
376
|
+
|
|
377
|
+
def _init_model_state(self):
    """Initialize a fresh model when there is no checkpoint to resume from."""
    skip_patterns = self.checkpoint_conf.skip_saving_parameters

    # Parameters that won't be saved must be frozen. Check this before even
    # saving the model, so errors surface as early as possible rather than at
    # the end of the first epoch.
    assert_skipped_parameters_are_frozen(patterns=skip_patterns, model=self.model)

    # Parameters that won't be saved must be initialized from within the model
    # definition, unless `initialize_after_preemption` is explicitly True.
    # Otherwise, after a preemption, `skip_saving_parameters` would hold
    # random values. The check is disabled when re-initialization after
    # preemption is allowed.
    with with_check_parameter_frozen(
        patterns=skip_patterns,
        model=self.model,
        disabled=self.checkpoint_conf.initialize_after_preemption,
    ):
        self._call_model_initializer()
|
|
397
|
+
|
|
398
|
+
def _call_model_initializer(self):
    """Apply the configured ``model_weight_initializer`` to ``self.model``.

    The config is instantiated with hydra. If that yields ``None``, nothing
    happens. Otherwise the initializer is called with ``model=self.model``
    and its return value replaces ``self.model``.
    """
    initializer_conf = self.checkpoint_conf.model_weight_initializer
    model_weight_initializer = instantiate(initializer_conf)
    if model_weight_initializer is None:
        return
    logging.info(f"Loading pretrained checkpoint from {initializer_conf}")
    self.model = model_weight_initializer(model=self.model)
|
|
403
|
+
|
|
404
|
+
def _load_resuming_checkpoint(self, ckpt_path: str):
    """Restore model, optimizer, loss, counters and AMP scaler from ``ckpt_path``.

    Parameters matching ``skip_saving_parameters`` are passed as
    ``ignore_missing_keys`` when loading the model state. Older checkpoints
    without ``time_elapsed`` or ``best_meter_values`` are accepted (0 seconds
    and an empty dict, respectively).

    Args:
        ckpt_path: path readable by ``g_pathmgr``.
    """
    logging.info(f"Resuming training from {ckpt_path}")

    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    with g_pathmgr.open(ckpt_path, "rb") as f:
        checkpoint = torch.load(f, map_location="cpu")
    load_state_dict_into_model(
        model=self.model,
        state_dict=checkpoint["model"],
        ignore_missing_keys=self.checkpoint_conf.skip_saving_parameters,
    )

    self.optim.optimizer.load_state_dict(checkpoint["optimizer"])
    self.loss.load_state_dict(checkpoint["loss"], strict=True)
    self.epoch = checkpoint["epoch"]
    self.steps = checkpoint["steps"]
    # Default to 0: this value is later added to a float
    # (time.time() - self.start_time + self.ckpt_time_elapsed). A None here
    # would raise TypeError on the first elapsed-time update.
    time_elapsed = checkpoint.get("time_elapsed")
    self.ckpt_time_elapsed = 0 if time_elapsed is None else time_elapsed

    if self.optim_conf.amp.enabled and "scaler" in checkpoint:
        self.scaler.load_state_dict(checkpoint["scaler"])

    self.best_meter_values = checkpoint.get("best_meter_values", {})

    if "train_dataset" in checkpoint and self.train_dataset is not None:
        self.train_dataset.load_checkpoint_state(checkpoint["train_dataset"])
|
|
428
|
+
|
|
429
|
+
def is_intermediate_val_epoch(self, epoch):
    """Return True if validation should run after ``epoch`` inside the training loop.

    Validation runs every ``val_epoch_freq`` epochs, except on the last epoch
    (``max_epochs - 1``), which is validated after the loop instead.
    """
    on_val_schedule = epoch % self.val_epoch_freq == 0
    before_last_epoch = epoch < self.max_epochs - 1
    return on_val_schedule and before_last_epoch
|
|
431
|
+
|
|
432
|
+
def _step(
    self,
    batch: BatchedVideoDatapoint,
    model: nn.Module,
    phase: str,
):
    """Forward ``batch`` through ``model``, compute its loss and update logging state.

    The loss function is selected by ``batch.dict_key``. If it returns a dict,
    the individual components are kept for logging, and
    ``_log_loss_detailed_and_return_core_loss`` reduces the dict to one loss.

    Returns:
        ``({"Losses/<phase>_<key>_loss": loss}, batch_size, step_losses)``.

    Side effects: increments ``self.steps[phase]`` and feeds ``outputs`` and
    ``batch.metadata`` to any meters registered for ``(phase, key)``.
    """
    outputs = model(batch)
    targets = batch.masks
    batch_size = len(batch.img_batch)

    dataset_key = batch.dict_key  # key for dataset
    loss = self.loss[dataset_key](outputs, targets)
    loss_str = f"Losses/{phase}_{dataset_key}_loss"
    loss_log_str = os.path.join("Step_Losses", loss_str)

    # loss contains multiple sub-components we wish to log
    if isinstance(loss, dict):
        step_losses = {f"Losses/{phase}_{dataset_key}_{name}": value for name, value in loss.items()}
        loss = self._log_loss_detailed_and_return_core_loss(loss, loss_log_str, self.steps[phase])
    else:
        step_losses = {}

    if self.steps[phase] % self.logging_conf.log_scalar_frequency == 0:
        self.logger.log(loss_log_str, loss, self.steps[phase])

    self.steps[phase] += 1

    if phase in self.meters and dataset_key in self.meters[phase]:
        meters_dict = self.meters[phase][dataset_key]
        if meters_dict is not None:
            for meter in meters_dict.values():
                meter.update(find_stages=outputs, find_metadatas=batch.metadata)

    return {loss_str: loss}, batch_size, step_losses
|
|
476
|
+
|
|
477
|
+
def run(self):
    """Entry point: dispatch on ``self.mode``.

    * ``"train"``: train, then validate once more at the end. When resuming
      (``epoch > 0``), first re-run the previous epoch's intermediate
      validation, because checkpoints are saved before validating.
    * ``"train_only"``: train without validation after the loop.
    * ``"val"``: validation only.
    """
    assert self.mode in ["train", "train_only", "val"]
    if self.mode == "val":
        self.run_val()
        return
    if self.mode == "train_only":
        self.run_train()
        return

    # mode == "train"
    if self.epoch > 0:
        logging.info(f"Resuming training from epoch: {self.epoch}")
        # resuming from a checkpoint
        previous_epoch = self.epoch - 1
        if self.is_intermediate_val_epoch(previous_epoch):
            logging.info("Running previous val epoch")
            self.epoch = previous_epoch
            self.run_val()
            self.epoch += 1
    self.run_train()
    self.run_val()
|
|
494
|
+
|
|
495
|
+
def _setup_dataloaders(self):
    """Instantiate (with hydra) the datasets required by ``self.mode``.

    Validation data is built for "train" and "val"; training data for "train"
    and "train_only". Datasets that are not needed stay ``None``.
    """
    needs_val = self.mode in ["train", "val"]
    needs_train = self.mode in ["train", "train_only"]

    self.train_dataset = None
    self.val_dataset = None

    if needs_val:
        self.val_dataset = instantiate(self.data_conf.get(Phase.VAL, None))
    if needs_train:
        self.train_dataset = instantiate(self.data_conf.train)
|
|
504
|
+
|
|
505
|
+
def run_train(self):
    """Train from ``self.epoch`` up to (excluding) ``max_epochs``.

    Per epoch:
    1. Build the train loader and hit a barrier.
    2. Run ``train_epoch`` and log its stats (rank 0 also appends a JSON line
       to ``<log_dir>/train_stats.json``).
    3. Save a checkpoint for ``epoch + 1`` *before* validating.
    4. Run an intermediate validation when ``is_intermediate_val_epoch``
       says so.
    5. On rank 0, append ``best_meter_values`` (plus trainer state) to
       ``<log_dir>/best_stats.json``.

    After the loop, ``self.epoch`` is decremented by one (after a full run it
    equals ``max_epochs - 1``). This way the final ``run_val`` issued by
    ``run()`` is attributed to the last trained epoch.
    """

    while self.epoch < self.max_epochs:
        dataloader = self.train_dataset.get_loader(epoch=int(self.epoch))
        barrier()
        outs = self.train_epoch(dataloader)
        self.logger.log_dict(outs, self.epoch)  # Logged only on rank 0

        # log train to text file.
        if self.distributed_rank == 0:
            with g_pathmgr.open(
                os.path.join(self.logging_conf.log_dir, "train_stats.json"),
                "a",
            ) as f:
                f.write(json.dumps(outs) + "\n")

        # Save checkpoint before validating
        # (on resume, `run()` re-runs the previous epoch's intermediate val)
        self.save_checkpoint(self.epoch + 1)

        # Drop the loader before validation and force a GC pass.
        del dataloader
        gc.collect()

        # Run val, not running on last epoch since will run after the
        # loop anyway
        if self.is_intermediate_val_epoch(self.epoch):
            self.run_val()

        if self.distributed_rank == 0:
            # NOTE(review): "train" is presumably the value of Phase.TRAIN — confirm.
            self.best_meter_values.update(self._get_trainer_state("train"))
            with g_pathmgr.open(
                os.path.join(self.logging_conf.log_dir, "best_stats.json"),
                "a",
            ) as f:
                f.write(json.dumps(self.best_meter_values) + "\n")

        self.epoch += 1
    # epoch was incremented in the loop but the val step runs out of the loop
    self.epoch -= 1
|
|
543
|
+
|
|
544
|
+
def run_val(self):
    """Run one validation epoch, if a val dataset exists, and record its stats.

    The results are passed to ``self.logger.log_dict``. On rank 0 they are
    also appended as one JSON line to ``<log_dir>/val_stats.json``.
    """
    if not self.val_dataset:
        return

    val_loader = self.val_dataset.get_loader(epoch=int(self.epoch))
    outs = self.val_epoch(val_loader, phase=Phase.VAL)
    # Drop the loader reference and force a GC pass before logging.
    del val_loader
    gc.collect()
    self.logger.log_dict(outs, self.epoch)  # Logged only on rank 0

    if self.distributed_rank != 0:
        return
    stats_path = os.path.join(self.logging_conf.log_dir, "val_stats.json")
    with g_pathmgr.open(stats_path, "a") as stats_file:
        stats_file.write(json.dumps(outs) + "\n")
|
|
560
|
+
|
|
561
|
+
def val_epoch(self, val_loader, phase):
    """Run the model over ``val_loader`` without gradients and collect metrics.

    Args:
        val_loader: iterable of batches exposing ``.to(device, non_blocking=...)``;
            ``len(val_loader)`` is used for progress/ETA.
        phase: phase name used for loss/meter names and timers.

    Returns:
        A dict built from:
        - the output of ``_log_meters_and_save_best_ckpts``;
        - the average of every loss meter and extra loss component;
        - the trainer state (where/epoch/steps) for each phase.
        Meters for the phase are reset afterwards.
    """
    batch_time = AverageMeter("Batch Time", self.device, ":.2f")
    data_time = AverageMeter("Data Time", self.device, ":.2f")
    mem = MemMeter("Mem (GB)", self.device, ":.2f")

    iters_per_epoch = len(val_loader)

    curr_phases = [phase]
    curr_models = [self.model]

    # One averaged loss meter per (phase, loss key).
    loss_names = []
    for p in curr_phases:
        for key in self.loss.keys():
            loss_names.append(f"Losses/{p}_{key}_loss")

    loss_mts = OrderedDict([(name, AverageMeter(name, self.device, ":.2e")) for name in loss_names])
    # Meters for loss sub-components, created lazily as they first appear.
    extra_loss_mts = {}

    for model in curr_models:
        model.eval()
        if hasattr(unwrap_ddp_if_wrapped(model), "on_validation_epoch_start"):
            unwrap_ddp_if_wrapped(model).on_validation_epoch_start()

    progress = ProgressMeter(
        iters_per_epoch,
        [batch_time, data_time, mem, self.time_elapsed_meter, *loss_mts.values()],
        self._get_meters(curr_phases),
        prefix="Val Epoch: [{}]".format(self.epoch),
    )

    end = time.time()

    for data_iter, batch in enumerate(val_loader):

        # measure data loading time
        data_time.update(time.time() - end)

        batch = batch.to(self.device, non_blocking=True)

        # compute output
        # NOTE(review): torch.cuda.amp.autocast is deprecated in recent PyTorch
        # in favour of torch.amp.autocast("cuda", ...).
        with torch.no_grad():
            with torch.cuda.amp.autocast(
                enabled=(self.optim_conf.amp.enabled if self.optim_conf else False),
                dtype=(get_amp_type(self.optim_conf.amp.amp_dtype) if self.optim_conf else None),
            ):
                # This loop rebinds `phase` and `model`; with a single phase
                # they equal the argument and self.model.
                for phase, model in zip(curr_phases, curr_models):
                    loss_dict, batch_size, extra_losses = self._step(
                        batch,
                        model,
                        phase,
                    )

                    assert len(loss_dict) == 1
                    loss_key, loss = loss_dict.popitem()

                    loss_mts[loss_key].update(loss.item(), batch_size)

                    for k, v in extra_losses.items():
                        if k not in extra_loss_mts:
                            extra_loss_mts[k] = AverageMeter(k, self.device, ":.2e")
                        extra_loss_mts[k].update(v.item(), batch_size)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # Includes time carried over from a resumed checkpoint.
        self.time_elapsed_meter.update(time.time() - self.start_time + self.ckpt_time_elapsed)

        if torch.cuda.is_available():
            mem.update(reset_peak_usage=True)

        if data_iter % self.logging_conf.log_freq == 0:
            progress.display(data_iter)

        if data_iter % self.logging_conf.log_scalar_frequency == 0:
            # Log progress meters.
            for progress_meter in progress.meters:
                self.logger.log(
                    os.path.join("Step_Stats", phase, progress_meter.name),
                    progress_meter.val,
                    self.steps[Phase.VAL],
                )

        # NOTE(review): a collective barrier every 10 iterations requires all
        # ranks to see the same number of batches, otherwise they can hang.
        if data_iter % 10 == 0:
            dist.barrier()

    self.est_epoch_time[phase] = batch_time.avg * iters_per_epoch
    self._log_timers(phase)
    for model in curr_models:
        if hasattr(unwrap_ddp_if_wrapped(model), "on_validation_epoch_end"):
            unwrap_ddp_if_wrapped(model).on_validation_epoch_end()

    out_dict = self._log_meters_and_save_best_ckpts(curr_phases)

    for k, v in loss_mts.items():
        out_dict[k] = v.avg
    for k, v in extra_loss_mts.items():
        out_dict[k] = v.avg

    for phase in curr_phases:
        out_dict.update(self._get_trainer_state(phase))
    self._reset_meters(curr_phases)
    logging.info(f"Meters: {out_dict}")
    return out_dict
|
|
665
|
+
|
|
666
|
+
def _get_trainer_state(self, phase):
|
|
667
|
+
return {
|
|
668
|
+
"Trainer/where": self.where,
|
|
669
|
+
"Trainer/epoch": self.epoch,
|
|
670
|
+
f"Trainer/steps_{phase}": self.steps[phase],
|
|
671
|
+
}
|
|
672
|
+
|
|
673
|
+
def train_epoch(self, train_loader):
    """Run one training epoch over ``train_loader`` and return its statistics.

    For every batch this:
    - moves the batch to ``self.device``;
    - runs forward/backward via ``self._run_step``;
    - steps the LR/param schedulers;
    - optionally unscales and clips gradients and logs them;
    - steps the optimizer through the AMP grad scaler;
    - logs timing, memory and meter values.
    After the loop, timers and synced data times are logged, meters are
    synchronized (possibly saving best checkpoints) and then reset.

    Args:
        train_loader: sized iterable of batches; each batch must support
            ``.to(device, non_blocking=True)``.

    Returns:
        dict of synced meter values, per-loss averages, extra-loss averages and
        the trainer state returned by ``self._get_trainer_state``.

    Raises:
        FloatingPointError: propagated unchanged from ``self._run_step`` when
            the loss is not finite, which stops training.
    """
    # Init stat meters
    batch_time_meter = AverageMeter("Batch Time", self.device, ":.2f")
    data_time_meter = AverageMeter("Data Time", self.device, ":.2f")
    mem_meter = MemMeter("Mem (GB)", self.device, ":.2f")
    data_times = []
    phase = Phase.TRAIN

    iters_per_epoch = len(train_loader)

    loss_names = [f"Losses/{phase}_{batch_key}_loss" for batch_key in self.loss.keys()]
    loss_mts = OrderedDict([(name, AverageMeter(name, self.device, ":.2e")) for name in loss_names])
    # Meters for extra losses are created lazily by _run_step as new keys appear.
    extra_loss_mts = {}

    progress = ProgressMeter(
        iters_per_epoch,
        [
            batch_time_meter,
            data_time_meter,
            mem_meter,
            self.time_elapsed_meter,
            *loss_mts.values(),
        ],
        self._get_meters([phase]),
        prefix="Train Epoch: [{}]".format(self.epoch),
    )

    # Model training loop
    self.model.train()
    end = time.time()

    for data_iter, batch in enumerate(train_loader):
        # measure data loading time
        data_time_meter.update(time.time() - end)
        data_times.append(data_time_meter.val)
        batch = batch.to(self.device, non_blocking=True)  # move tensors in a tensorclass

        # Forward/backward. A NaN/Inf loss raises FloatingPointError inside
        # _run_step and is deliberately left to propagate (no re-raise wrapper).
        self._run_step(batch, phase, loss_mts, extra_loss_mts)

        # compute gradient and do optim step
        # `where` is the fraction of total training completed, in [0, 1].
        exact_epoch = self.epoch + float(data_iter) / iters_per_epoch
        self.where = float(exact_epoch) / self.max_epochs
        assert self.where <= 1 + self.EPSILON
        if self.where < 1.0:
            self.optim.step_schedulers(self.where, step=int(exact_epoch * iters_per_epoch))
        else:
            logging.warning(
                f"Skipping scheduler update since the training is at the end, i.e, {self.where} of [0,1]."
            )

        # Log schedulers
        if data_iter % self.logging_conf.log_scalar_frequency == 0:
            param_groups = self.optim.optimizer.param_groups
            for j, param_group in enumerate(param_groups):
                # Only prefix with the group index when there is more than one group.
                optim_prefix = f"{j}_" if len(param_groups) > 1 else ""
                for option in self.optim.schedulers[j]:
                    self.logger.log(
                        os.path.join("Optim", optim_prefix, option),
                        param_group[option],
                        self.steps[phase],
                    )

        # Clipping gradients and detecting diverging gradients
        if self.gradient_clipper is not None:
            # Gradients are unscaled first so clipping sees their true magnitude.
            self.scaler.unscale_(self.optim.optimizer)
            self.gradient_clipper(model=self.model)

        if self.gradient_logger is not None:
            self.gradient_logger(self.model, rank=self.distributed_rank, where=self.where)

        # Optimizer step: the scaler will make sure gradients are not
        # applied if the gradients are infinite
        self.scaler.step(self.optim.optimizer)
        self.scaler.update()

        # measure elapsed time
        batch_time_meter.update(time.time() - end)
        end = time.time()

        self.time_elapsed_meter.update(time.time() - self.start_time + self.ckpt_time_elapsed)

        # Guarded exactly like the memory-meter update in the validation epoch.
        # NOTE(review): assumes MemMeter reads CUDA memory statistics; confirm.
        if torch.cuda.is_available():
            mem_meter.update(reset_peak_usage=True)
        if data_iter % self.logging_conf.log_freq == 0:
            progress.display(data_iter)

        if data_iter % self.logging_conf.log_scalar_frequency == 0:
            # Log progress meters.
            for progress_meter in progress.meters:
                self.logger.log(
                    os.path.join("Step_Stats", phase, progress_meter.name),
                    progress_meter.val,
                    self.steps[phase],
                )

    self.est_epoch_time[Phase.TRAIN] = batch_time_meter.avg * iters_per_epoch
    self._log_timers(Phase.TRAIN)
    self._log_sync_data_times(Phase.TRAIN, data_times)

    out_dict = self._log_meters_and_save_best_ckpts([Phase.TRAIN])

    for k, v in loss_mts.items():
        out_dict[k] = v.avg
    for k, v in extra_loss_mts.items():
        out_dict[k] = v.avg
    out_dict.update(self._get_trainer_state(phase))
    logging.info(f"Losses and meters: {out_dict}")
    self._reset_meters([phase])
    return out_dict
|
|
789
|
+
|
|
790
|
+
def _log_sync_data_times(self, phase, data_times):
    """Log the per-step data-loading times after reducing them with ``all_reduce_max``.

    The i-th entry of ``data_times`` is attributed to step
    ``self.steps[phase] - len(data_times) + i``. Only steps that are multiples
    of ``logging_conf.log_scalar_frequency`` are logged.

    Args:
        phase: phase name used in the log path.
        data_times: list of per-iteration data-loading times (seconds).
            Presumably reduced with a max across distributed ranks; confirm
            against ``all_reduce_max``.
    """
    synced_times = all_reduce_max(torch.tensor(data_times)).tolist()
    first_step = self.steps[phase] - len(synced_times)
    for offset, data_time in enumerate(synced_times):
        step = first_step + offset
        if step % self.logging_conf.log_scalar_frequency != 0:
            continue
        self.logger.log(
            os.path.join("Step_Stats", phase, "Data Time Synced"),
            data_time,
            step,
        )
|
|
800
|
+
|
|
801
|
+
def _run_step(
    self,
    batch: BatchedVideoDatapoint,
    phase: str,
    loss_mts: Dict[str, AverageMeter],
    extra_loss_mts: Dict[str, AverageMeter],
    raise_on_error: bool = True,
):
    """
    Run the forward / backward pass for one batch and update the loss meters.

    Args:
        batch: batch already moved to the training device.
        phase: current phase, forwarded to ``self._step``.
        loss_mts: meters keyed by loss name; the single key returned by
            ``self._step`` must already be present.
        extra_loss_mts: meters for extra losses; new keys are created on demand.
        raise_on_error: if True a non-finite loss raises; otherwise the step is
            skipped (no backward, no meter update).

    Raises:
        FloatingPointError: when the loss is NaN/Inf and ``raise_on_error`` is True.
    """
    # it's important to set grads to None, especially with Adam since 0
    # grads will also update a model even if the step doesn't produce
    # gradients
    self.optim.zero_grad(set_to_none=True)
    # torch.cuda.amp.autocast(...) is deprecated; torch.amp.autocast("cuda", ...)
    # is its documented drop-in equivalent.
    with torch.amp.autocast(
        "cuda",
        enabled=self.optim_conf.amp.enabled,
        dtype=get_amp_type(self.optim_conf.amp.amp_dtype),
    ):
        loss_dict, batch_size, extra_losses = self._step(
            batch,
            self.model,
            phase,
        )

    assert len(loss_dict) == 1
    loss_key, loss = loss_dict.popitem()

    # Read the scalar once: .item() forces a device->host sync.
    loss_value = loss.item()
    if not math.isfinite(loss_value):
        error_msg = f"Loss is {loss_value}, attempting to stop training"
        logging.error(error_msg)
        if raise_on_error:
            raise FloatingPointError(error_msg)
        return

    self.scaler.scale(loss).backward()
    loss_mts[loss_key].update(loss_value, batch_size)
    for extra_loss_key, extra_loss in extra_losses.items():
        if extra_loss_key not in extra_loss_mts:
            extra_loss_mts[extra_loss_key] = AverageMeter(extra_loss_key, self.device, ":.2e")
        extra_loss_mts[extra_loss_key].update(extra_loss.item(), batch_size)
|
|
844
|
+
|
|
845
|
+
def _log_meters_and_save_best_ckpts(self, phases: List[str]):
|
|
846
|
+
logging.info("Synchronizing meters")
|
|
847
|
+
out_dict = {}
|
|
848
|
+
checkpoint_save_keys = []
|
|
849
|
+
for key, meter in self._get_meters(phases).items():
|
|
850
|
+
meter_output = meter.compute_synced()
|
|
851
|
+
is_better_check = getattr(meter, "is_better", None)
|
|
852
|
+
|
|
853
|
+
for meter_subkey, meter_value in meter_output.items():
|
|
854
|
+
out_dict[os.path.join("Meters_train", key, meter_subkey)] = meter_value
|
|
855
|
+
|
|
856
|
+
if is_better_check is None:
|
|
857
|
+
continue
|
|
858
|
+
|
|
859
|
+
tracked_meter_key = os.path.join(key, meter_subkey)
|
|
860
|
+
if tracked_meter_key not in self.best_meter_values or is_better_check(
|
|
861
|
+
meter_value,
|
|
862
|
+
self.best_meter_values[tracked_meter_key],
|
|
863
|
+
):
|
|
864
|
+
self.best_meter_values[tracked_meter_key] = meter_value
|
|
865
|
+
|
|
866
|
+
if (
|
|
867
|
+
self.checkpoint_conf.save_best_meters is not None
|
|
868
|
+
and key in self.checkpoint_conf.save_best_meters
|
|
869
|
+
):
|
|
870
|
+
checkpoint_save_keys.append(tracked_meter_key.replace("/", "_"))
|
|
871
|
+
|
|
872
|
+
if len(checkpoint_save_keys) > 0:
|
|
873
|
+
self.save_checkpoint(self.epoch + 1, checkpoint_save_keys)
|
|
874
|
+
|
|
875
|
+
return out_dict
|
|
876
|
+
|
|
877
|
+
def _log_timers(self, phase):
    """Log the elapsed-time meter for ``phase`` and an estimate of the remaining run time.

    The estimate is:
        (remaining train epochs) * est_epoch_time[TRAIN]
        + (remaining val runs) * est_epoch_time[VAL]

    Args:
        phase: the phase that just finished (compared against ``Phase.VAL``).
    """
    train_epochs_left = self.max_epochs - self.epoch - 1
    val_runs_left = len([n for n in range(self.epoch, self.max_epochs) if n % self.val_epoch_freq == 0])

    # The final epoch always gets a val run; count it when the regular
    # val_epoch_freq schedule does not already land on it.
    if (self.max_epochs - 1) % self.val_epoch_freq != 0:
        val_runs_left += 1

    # A val run that is finishing right now is no longer remaining.
    if phase == Phase.VAL:
        val_runs_left -= 1

    time_remaining = (
        train_epochs_left * self.est_epoch_time[Phase.TRAIN] + val_runs_left * self.est_epoch_time[Phase.VAL]
    )

    self.logger.log(
        os.path.join("Step_Stats", phase, self.time_elapsed_meter.name),
        self.time_elapsed_meter.val,
        self.steps[phase],
    )

    logging.info(f"Estimated time remaining: {human_readable_time(time_remaining)}")
|
|
902
|
+
|
|
903
|
+
def _reset_meters(self, phases: str) -> None:
|
|
904
|
+
for meter in self._get_meters(phases).values():
|
|
905
|
+
meter.reset()
|
|
906
|
+
|
|
907
|
+
def _check_val_key_match(self, val_keys, phase):
|
|
908
|
+
if val_keys is not None:
|
|
909
|
+
# Check if there are any duplicates
|
|
910
|
+
assert len(val_keys) == len(set(val_keys)), f"Duplicate keys in val datasets, keys: {val_keys}"
|
|
911
|
+
|
|
912
|
+
# Check that the keys match the meter keys
|
|
913
|
+
if self.meters_conf is not None and phase in self.meters_conf:
|
|
914
|
+
assert set(val_keys) == set(self.meters_conf[phase].keys()), (
|
|
915
|
+
f"Keys in val datasets do not match the keys in meters."
|
|
916
|
+
f"\nMissing in meters: {set(val_keys) - set(self.meters_conf[phase].keys())}"
|
|
917
|
+
f"\nMissing in val datasets: {set(self.meters_conf[phase].keys()) - set(val_keys)}"
|
|
918
|
+
)
|
|
919
|
+
|
|
920
|
+
if self.loss_conf is not None:
|
|
921
|
+
loss_keys = set(self.loss_conf.keys()) - set(["all"])
|
|
922
|
+
assert all([k in loss_keys for k in val_keys]), (
|
|
923
|
+
f"Keys in val datasets do not match the keys in losses."
|
|
924
|
+
f"\nMissing in losses: {set(val_keys) - loss_keys}"
|
|
925
|
+
f"\nMissing in val datasets: {loss_keys - set(val_keys)}"
|
|
926
|
+
)
|
|
927
|
+
|
|
928
|
+
def _setup_components(self):
    """Build the trainer's components from the configuration.

    Steps, in order:
    - validate the val dataset keys against the meter/loss configs;
    - reset the epoch/step counters and create the logger;
    - instantiate the model (and log its summary), the losses (as an
      ``nn.ModuleDict``) and the meters;
    - build the AMP ``GradScaler``, the gradient clipper and the gradient
      logger. The last three are only created when ``self.optim_conf`` is set.
    """
    # Collect the keys of all val datasets (if any) and check them against the config.
    val_phase = Phase.VAL
    has_val_data = self.data_conf.get(val_phase, None) is not None
    val_keys = collect_dict_keys(self.data_conf[val_phase]) if has_val_data else None
    self._check_val_key_match(val_keys, phase=val_phase)

    logging.info("Setting up components: Model, loss, optim, meters etc.")
    self.epoch = 0
    self.steps = {Phase.TRAIN: 0, Phase.VAL: 0}

    self.logger = Logger(self.logging_conf)

    self.model = instantiate(self.model_conf, _convert_="all")
    print_model_summary(self.model)

    self.loss = None
    if self.loss_conf:
        loss_modules = instantiate(self.loss_conf, _convert_="all")
        self.loss = nn.ModuleDict(dict(loss_modules.items()))

    self.best_meter_values = {}
    self.meters = instantiate(self.meters_conf, _convert_="all") if self.meters_conf else {}

    amp_enabled = self.optim_conf.amp.enabled if self.optim_conf else False
    self.scaler = torch.amp.GradScaler(self.device, enabled=amp_enabled)

    if self.optim_conf:
        self.gradient_clipper = instantiate(self.optim_conf.gradient_clip)
        self.gradient_logger = instantiate(self.optim_conf.gradient_logger)
    else:
        self.gradient_clipper = None
        self.gradient_logger = None

    logging.info("Finished setting up components: Model, loss, optim, meters etc.")
|
|
968
|
+
|
|
969
|
+
def _construct_optimizers(self):
    """Create ``self.optim`` by calling ``construct_optimizer`` on the model.

    The optimizer, options and param-group modifiers are taken from
    ``self.optim_conf``.
    """
    conf = self.optim_conf
    self.optim = construct_optimizer(
        self.model,
        conf.optimizer,
        conf.options,
        conf.param_group_modifiers,
    )
|
|
976
|
+
|
|
977
|
+
def _log_loss_detailed_and_return_core_loss(self, loss, loss_str, step):
    """Remove and return the core loss from ``loss``; periodically log the remaining terms.

    Note: ``loss`` is mutated, because the ``CORE_LOSS_KEY`` entry is popped.

    Args:
        loss: dict of named loss terms that includes ``CORE_LOSS_KEY``.
        loss_str: log path prefix for the individual terms.
        step: current step. Terms are logged only when it is a multiple of
            ``logging_conf.log_scalar_frequency``.

    Returns:
        The value previously stored under ``CORE_LOSS_KEY``.
    """
    core_loss = loss.pop(CORE_LOSS_KEY)
    if step % self.logging_conf.log_scalar_frequency != 0:
        return core_loss
    for term_name, term_value in loss.items():
        self.logger.log(os.path.join(loss_str, term_name), term_value, step)
    return core_loss
|
|
984
|
+
|
|
985
|
+
|
|
986
|
+
def print_model_summary(model: torch.nn.Module, log_dir: str = ""):
    """
    Log the model and its total / trainable / non-trainable parameter counts.

    Only rank 0 logs; other ranks return immediately.

    Multiple packages provide this info in a nice table format, but they need an
    example `input` (they also report output sizes), and a single input is too
    restrictive for our models:
    https://github.com/sksq96/pytorch-summary
    https://github.com/nmhkahn/torchsummaryX

    Args:
        model: module to summarize.
        log_dir: if non-empty, the model's repr is also written to
            ``<log_dir>/model.txt``.
    """
    if get_rank() != 0:
        return

    # Count all parameters and the trainable subset in one pass.
    total_parameters = 0
    trainable_parameters = 0
    for param in model.parameters():
        count = param.numel()
        total_parameters += count
        if param.requires_grad:
            trainable_parameters += count
    non_trainable_parameters = total_parameters - trainable_parameters

    separator = "==" * 10
    logging.info(separator)
    logging.info(f"Summary for model {type(model)}")
    logging.info(f"Model is {model}")
    logging.info(f"\tTotal parameters {get_human_readable_count(total_parameters)}")
    logging.info(f"\tTrainable parameters {get_human_readable_count(trainable_parameters)}")
    logging.info(f"\tNon-Trainable parameters {get_human_readable_count(non_trainable_parameters)}")
    logging.info(separator)

    if not log_dir:
        return
    with g_pathmgr.open(os.path.join(log_dir, "model.txt"), "w") as f:
        print(model, file=f)
|
|
1013
|
+
|
|
1014
|
+
|
|
1015
|
+
# Unit suffixes for successive groups of three digits: ones, thousands, millions,
# billions, trillions. Index 0 is a single space.
PARAMETER_NUM_UNITS = [" ", "K", "M", "B", "T"]


def get_human_readable_count(number: int) -> str:
    """
    Abbreviate a non-negative count with a K / M / B / T suffix.

    Counts below one thousand, and scaled values of at least 100, are shown as
    truncated integers with thousands separators. Other scaled values get one
    decimal place. Values beyond trillions stay in "T".

    Examples:
        >>> get_human_readable_count(123)
        '123  '
        >>> get_human_readable_count(1234)  # (one thousand)
        '1.2 K'
        >>> get_human_readable_count(2e6)  # (two million)
        '2.0 M'
        >>> get_human_readable_count(3e9)  # (three billion)
        '3.0 B'
        >>> get_human_readable_count(4e14)  # (four hundred trillion)
        '400 T'
        >>> get_human_readable_count(5e15)  # (more than trillion)
        '5,000 T'

    Args:
        number: a non-negative number (normally an integer count).

    Returns:
        The abbreviated string described above.
    """
    assert number >= 0
    units = PARAMETER_NUM_UNITS
    digit_count = 1 if number <= 0 else int(np.floor(np.log10(number)) + 1)
    # Integer ceil-division: how many 3-digit groups, capped at the largest unit.
    group_count = min(-(-digit_count // 3), len(units))
    unit_index = group_count - 1
    scaled = number * (10 ** (-3 * unit_index))
    if unit_index >= 1 and scaled < 100:
        return f"{scaled:,.1f} {units[unit_index]}"
    return f"{int(scaled):,d} {units[unit_index]}"
|