TinyExp 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinyexp/__init__.py +149 -0
- tinyexp/dataset/__init__.py +7 -0
- tinyexp/dataset/fake_dataloader.py +25 -0
- tinyexp/dataset/sampler.py +63 -0
- tinyexp/examples/__init__.py +0 -0
- tinyexp/examples/mnist_exp.py +249 -0
- tinyexp/examples/resnet_exp.py +431 -0
- tinyexp/exceptions.py +46 -0
- tinyexp/tiny_engine/accelerator/__init__.py +12 -0
- tinyexp/tiny_engine/accelerator/base_accelerator.py +71 -0
- tinyexp/tiny_engine/accelerator/cpu_accelerator.py +79 -0
- tinyexp/tiny_engine/accelerator/ddp_accelerator.py +154 -0
- tinyexp/tiny_engine/accelerator/hf_accelerator.py +21 -0
- tinyexp/utils/__init__.py +0 -0
- tinyexp/utils/log_utils.py +48 -0
- tinyexp/utils/model_utils.py +17 -0
- tinyexp/utils/ray_utils.py +183 -0
- tinyexp/utils/redis_utils.py +231 -0
- tinyexp-0.0.1.dist-info/METADATA +139 -0
- tinyexp-0.0.1.dist-info/RECORD +22 -0
- tinyexp-0.0.1.dist-info/WHEEL +4 -0
- tinyexp-0.0.1.dist-info/licenses/LICENSE +21 -0
tinyexp/__init__.py
ADDED
|
@@ -0,0 +1,149 @@
|
|
|
1
|
+
__author__ = "LI Zeming"
|
|
2
|
+
__email__ = "zane.li@connect.ust.hk"
|
|
3
|
+
__license__ = "MIT"
|
|
4
|
+
|
|
5
|
+
import os
|
|
6
|
+
from dataclasses import dataclass, field
|
|
7
|
+
from typing import Optional
|
|
8
|
+
|
|
9
|
+
from hydra.conf import HydraConf, RunDir
|
|
10
|
+
from hydra.core.config_store import ConfigStore
|
|
11
|
+
from omegaconf import DictConfig
|
|
12
|
+
from omegaconf.listconfig import ListConfig
|
|
13
|
+
|
|
14
|
+
from .exceptions import UnknownConfigurationKeyError
|
|
15
|
+
from .utils.log_utils import tiny_logger_setup
|
|
16
|
+
from .utils.ray_utils import simple_launch_exp
|
|
17
|
+
|
|
18
|
+
# Public API of the package; store_and_run_exp is the entry point used by the bundled examples.
__all__ = ["ConfigStore", "RedisCfgMixin", "TinyExp", "simple_launch_exp", "store_and_run_exp"]
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@dataclass
class _HydraConfig(HydraConf):
    """
    Hydra settings used by TinyExp so that hydra does not write its config into an unexpected directory.
    """

    # No hydra output sub-directory is configured.
    output_subdir: Optional[str] = None
    # Run directory handed to hydra; a fresh RunDir is built for every instance.
    run: RunDir = field(default_factory=lambda: RunDir(dir="./output"))
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
@dataclass
class TinyExp:
    """
    Base experiment configuration whose fields are managed and overridden through hydra.

    An experiment is described as a dataclass so that the same object can be instantiated
    and used from different launch contexts, such as Ray or TorchRun.
    """

    hydra: _HydraConfig = field(default_factory=_HydraConfig)

    # ---------------- launcher configuration ---------------- #
    num_worker: int = -1  # Number of workers, -1 means to be determined by the user
    num_gpus_per_worker: float = 1.0  # Number of GPUs per worker, should be a float value between 0 and 1.

    output_root: str = "./output"
    overrided_cfg: dict = field(default_factory=dict)

    def __repr__(self):
        # Short representation (rank only) so that Ray logs stay readable.
        rank = os.getenv("RANK", "N/A")
        return f"Exp(rank={rank})"

    @dataclass
    class WandbCfg:
        enable_wandb: bool = False

        def build_wandb(self, accelerator=None, **kwargs):
            """
            Import wandb and, on rank 0 (or when no accelerator is given), call ``wandb.init(**kwargs)``.

            Returns:
                The ``wandb`` module when ``enable_wandb`` is true, otherwise ``None``.
            """
            if not self.enable_wandb:
                return None

            import wandb

            is_main = accelerator is None or accelerator.rank == 0
            if is_main:
                wandb.init(**kwargs)
            return wandb

    wandb_cfg: WandbCfg = field(default_factory=WandbCfg)

    @dataclass
    class LoggerCfg:
        def build_logger(self, save_dir: str, distributed_rank: int = 0, filename: str = "log.txt", mode: str = "a"):
            """
            Create the logger through ``tiny_logger_setup`` and log where the log file lives.

            Returns:
                The logger returned by ``tiny_logger_setup``.
            """
            logger = tiny_logger_setup(save_dir, distributed_rank, filename, mode)
            log_file = os.path.join(save_dir, filename)
            logger.info(f"==> log file: {log_file}")
            return logger

    logger_cfg: LoggerCfg = field(default_factory=LoggerCfg)

    def set_cfg(self, cfg_hydra, cfg_object=None):
        """
        Copy the values of ``cfg_hydra`` onto ``cfg_object`` (``self`` when omitted), recursing into nested mappings.

        Every leaf whose value differs from the current attribute is written to the target and
        recorded in ``self.overrided_cfg`` under its leaf key.

        Args:
            cfg_hydra: Mapping (``DictConfig`` or ``dict``) of attribute names to values.
            cfg_object: Object receiving the values; defaults to this experiment.

        Returns:
            The object that was updated.

        Raises:
            UnknownConfigurationKeyError: If a key has no matching attribute on the target object.
        """
        target = self if cfg_object is None else cfg_object

        for key, value in cfg_hydra.items():
            if not hasattr(target, key):
                raise UnknownConfigurationKeyError(key)

            if isinstance(value, (DictConfig, dict)):
                # Nested section: descend into the matching sub-config object.
                self.set_cfg(value, getattr(target, key))
                continue

            ori_value = getattr(target, key, None)
            if ori_value == value:
                continue

            # Only rank 0 (or a process without RANK set) reports the override.
            if os.getenv("RANK", 0) in (0, "0"):
                print(f"{key}: {value} <-- {ori_value}(original)")
            setattr(target, key, value)
            self.overrided_cfg[key] = value

        return target
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
@dataclass
class RedisCfgMixin:
    """
    Mixin that adds a Redis cache configuration section (``redis_cache_cfg``) to an experiment.
    """

    @dataclass
    class RedisCacheCfg:
        redis_cache_enabled: bool = True
        # List of Redis cache shard used ports
        redis_cache_shard_ports: ListConfig = field(default_factory=lambda: ListConfig([7000, 7001, 7002, 7003, 7004]))
        redis_cache_max_memory: int = 160  # Maximum memory is 160GB, according to the ImageNet dataset size
        redis_cluster_manager_cpus: int = 10

        def build_redis_cache(self):
            """
            Start a Redis cluster on the configured shard ports.

            The memory budget is split evenly (integer division) across the ports.

            Returns:
                The result of ``RedisClusterManager.start_redis_cluster()`` when the cache is
                enabled, otherwise ``True``.
            """
            if not self.redis_cache_enabled:
                return True

            from tinyexp.utils.redis_utils import RedisClusterManager

            shard_ports = self.redis_cache_shard_ports
            memory_per_port = self.redis_cache_max_memory // len(shard_ports)
            manager = RedisClusterManager(ports=shard_ports, max_memory_per_port=memory_per_port)
            return manager.start_redis_cluster()

    redis_cache_cfg: RedisCacheCfg = field(default_factory=RedisCacheCfg)

    def proxy_build_redis_cache(self):
        """
        Build the Redis cache by delegating to ``redis_cache_cfg.build_redis_cache()``.

        NOTE(review): the original note says this hard-coded proxy exists because of a Ray actor
        requirement; the exact requirement is not stated in this file -- confirm with the caller.
        """
        cache_cfg = self.redis_cache_cfg
        return cache_cfg.build_redis_cache()
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
def store_and_run_exp(exp_class: type[TinyExp]) -> None:
    """
    Register ``exp_class`` in hydra's ConfigStore under the name ``"cfg"`` and launch the experiment.

    Args:
        exp_class: The experiment class whose configuration is stored and run.

    Returns:
        None: This function does not return anything.
    """
    config_store = ConfigStore.instance()
    config_store.store(name="cfg", node=exp_class)
    simple_launch_exp()
|
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
import copy
|
|
2
|
+
|
|
3
|
+
__all__ = ["HoldOnesampleDataLoader"]
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class HoldOnesampleDataLoader:
    """
    A fake dataloader that holds one sample from the original dataloader.
    This is useful for testing and profiling purposes, where we want to test the model with a single in-memory sample.

    Iteration never ends: every ``next()`` returns a deep copy of the held sample.
    Use ``len()`` (the length of the wrapped dataloader) to bound a loop.
    """

    def __init__(self, dataloader):
        """
        Args:
            dataloader: Iterable to take the first sample from; it must also support ``len()``
                if ``len()`` is called on this object.

        Raises:
            ValueError: If ``dataloader`` yields no sample at all.
        """
        self.dataloader = dataloader
        try:
            self.sample = next(iter(dataloader))
        except StopIteration:
            # Fail here instead of later with an AttributeError on ``self.sample``.
            raise ValueError("HoldOnesampleDataLoader requires a dataloader that yields at least one sample") from None

    def __iter__(self):
        return self

    def __next__(self):
        # Deep copy so that in-place changes made by the consumer never leak into the held sample.
        return copy.deepcopy(self.sample)

    def __len__(self):
        return len(self.dataloader)
|
|
@@ -0,0 +1,63 @@
|
|
|
1
|
+
import itertools
|
|
2
|
+
from typing import Optional
|
|
3
|
+
|
|
4
|
+
import torch
|
|
5
|
+
from torch.utils.data.sampler import Sampler
|
|
6
|
+
|
|
7
|
+
__all__ = ["InfiniteSampler"]
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class InfiniteSampler(Sampler):
    """
    In training, we only care about the "infinite stream" of training data.
    So this sampler produces an infinite stream of indices and
    all workers cooperate to correctly shuffle the indices and sample different indices.
    The samplers in each worker effectively produces `indices[worker_id::num_workers]`
    where `indices` is an infinite stream of indices consisting of
    `shuffle(range(size)) + shuffle(range(size)) + ...` (if shuffle is True)
    or `range(size) + range(size) + ...` (if shuffle is False)
    """

    def __init__(self, size: int, shuffle: bool = True, seed: Optional[int] = 0, drop_last=False, accelerator=None):
        """
        Args:
            size (int): the total number of data of the underlying dataset to sample from
            shuffle (bool): whether to shuffle the indices or not
            seed (int): the initial seed of the shuffle. Must be the same
                across all workers. If None, a random seed is drawn; with more than one
                worker it is broadcast from rank 0 through ``torch.distributed``, which
                must already be initialized.
            drop_last (bool): only affects ``__len__``: the per-worker length is rounded
                down when True and rounded up when False
            accelerator: optional object exposing ``rank`` and ``world_size``; when omitted
                the sampler behaves as rank 0 of a single worker

        Raises:
            ValueError: if ``seed`` is None with more than one worker and no initialized
                ``torch.distributed`` process group to share the seed through
        """
        self._size = size
        self._shuffle = shuffle
        self.drop_last = drop_last
        if accelerator is not None:
            self._rank = accelerator.rank
            self._world_size = accelerator.world_size
        else:
            self._rank = 0
            self._world_size = 1
        # ``int(None)`` would raise, so the documented "random shared seed" case is handled explicitly.
        self._seed = self._shared_random_seed() if seed is None else int(seed)

    def _shared_random_seed(self) -> int:
        """Draw a random seed and, with several workers, make every worker adopt rank 0's value."""
        # Generator.seed() picks a non-deterministic seed without touching the global RNG state.
        seed = torch.Generator().seed() % (2**31)
        if self._world_size == 1:
            return seed

        dist = torch.distributed
        if not (dist.is_available() and dist.is_initialized()):
            # Different seeds per worker would silently produce overlapping index streams.
            raise ValueError(
                "seed=None with more than one worker requires an initialized torch.distributed "
                "process group; pass an explicit seed instead"
            )
        holder = [seed]
        dist.broadcast_object_list(holder, src=0)
        return int(holder[0])

    def set_epoch(self, epoch):
        # No-op: the index stream is continuous and is not re-seeded per epoch.
        # NOTE(review): presumably kept so callers that expect a ``set_epoch`` method keep working.
        pass

    def __iter__(self):
        if self._size <= 0:
            # With nothing to yield, the endless loop in _infinite_indices would spin forever.
            raise ValueError(f"InfiniteSampler needs a positive size to iterate, got {self._size}")
        # Each worker takes every ``world_size``-th index, starting at its own rank.
        yield from itertools.islice(self._infinite_indices(), self._rank, None, self._world_size)

    def _infinite_indices(self):
        """Yield the un-sharded endless index stream; identical on every worker for a given seed."""
        g = torch.Generator()
        g.manual_seed(self._seed)
        while True:
            if self._shuffle:
                yield from torch.randperm(self._size, generator=g).tolist()
            else:
                yield from range(self._size)

    def __len__(self):
        # Per-worker share of one pass over the dataset: floor with drop_last, ceil otherwise.
        if self.drop_last:
            return self._size // self._world_size
        return (self._size + self._world_size - 1) // self._world_size
|
|
File without changes
|
|
@@ -0,0 +1,249 @@
|
|
|
1
|
+
import datetime
|
|
2
|
+
import os
|
|
3
|
+
from dataclasses import dataclass, field
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
import torch.nn as nn
|
|
7
|
+
import torch.nn.functional as F
|
|
8
|
+
import torch.optim as optim
|
|
9
|
+
import wandb
|
|
10
|
+
from omegaconf import OmegaConf
|
|
11
|
+
from torch.optim.lr_scheduler import StepLR
|
|
12
|
+
from torchvision import datasets, transforms
|
|
13
|
+
|
|
14
|
+
from tinyexp import TinyExp, store_and_run_exp
|
|
15
|
+
from tinyexp.exceptions import UnknownAcceleratorTypeError
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class Net(nn.Module):
    """
    Small convolutional classifier: two 3x3 convolutions, max-pooling, and two linear layers with dropout.

    The first linear layer expects 9216 flattened features, which is what a 1x28x28 input produces.
    """

    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1)
        self.dropout1 = nn.Dropout(p=0.25)
        self.dropout2 = nn.Dropout(p=0.5)
        self.fc1 = nn.Linear(in_features=9216, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=10)
        self.loss = F.nll_loss

    def forward(self, x, target=None, onnx_exporting=False) -> torch.Tensor:
        """
        Run the network.

        Args:
            x: Input batch.
            target: Optional class indices; only used while the module is in training mode.
            onnx_exporting: When true, return the raw output of the last linear layer.

        Returns:
            The raw logits if ``onnx_exporting`` is true; the NLL loss if the module is in
            training mode and ``target`` is given; otherwise the log-softmax scores.
        """
        hidden = F.relu(self.conv1(x))
        hidden = F.relu(self.conv2(hidden))
        hidden = self.dropout1(F.max_pool2d(hidden, 2))
        hidden = F.relu(self.fc1(torch.flatten(hidden, 1)))
        logits = self.fc2(self.dropout2(hidden))
        if onnx_exporting:
            return logits

        log_probs = F.log_softmax(logits, dim=1)
        if target is None or not self.training:
            return log_probs
        return self.loss(log_probs, target)
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
@dataclass(repr=False)
class Exp(TinyExp):
    """
    MNIST example experiment built on ``TinyExp``.

    Each ``*Cfg`` section holds the options of one component together with the ``build_*``
    method that creates it; ``run`` wires the components together and trains ``Net`` for 3 epochs.
    """

    exp_class: str = f"{__name__}.Exp"
    num_worker: int = 2
    num_gpus_per_worker: float = 0.0
    # NOTE(review): ``mode`` is not consulted anywhere in this class; ``run`` always trains.
    mode: str = "train"

    @dataclass
    class AcceleratorCfg:
        accelerator: str = "cpu"  # "cpu" or "ddp"

        def build_accelerator(self):
            """
            Instantiate the accelerator named by ``accelerator``.

            Raises:
                UnknownAcceleratorTypeError: For any value other than "cpu" or "ddp".
            """
            from tinyexp.tiny_engine.accelerator import CPUAccelerator, DDPAccelerator

            if self.accelerator == "cpu":
                accelerator = CPUAccelerator()
            elif self.accelerator == "ddp":
                accelerator = DDPAccelerator()
            else:
                raise UnknownAcceleratorTypeError(self.accelerator)
            return accelerator

    accelerator_cfg: AcceleratorCfg = field(default_factory=AcceleratorCfg)

    @dataclass
    class DataloaderCfg:
        data_root: str = "./data/"
        train_batch_size_per_device: int = 32
        train_data_worker_per_gpu: int = 2
        val_data_worker_per_gpu: int = 1

        def build_train_dataloader(self, accelerator):
            """Build the MNIST training dataloader, sharded across workers with a DistributedSampler."""
            transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
            ds_train = datasets.MNIST(self.data_root, train=True, download=True, transform=transform)
            sampler = torch.utils.data.DistributedSampler(
                ds_train, num_replicas=accelerator.world_size, rank=accelerator.rank
            )
            dl_train = torch.utils.data.DataLoader(
                ds_train,
                batch_size=self.train_batch_size_per_device,
                shuffle=False,
                num_workers=self.train_data_worker_per_gpu,
                drop_last=True,
                sampler=sampler,
            )
            return dl_train

        def build_val_dataloader(self, accelerator):
            """Build the MNIST validation dataloader; it reuses the training batch size."""
            transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
            ds_val = datasets.MNIST(self.data_root, train=False, download=True, transform=transform)
            sampler = torch.utils.data.DistributedSampler(
                ds_val, num_replicas=accelerator.world_size, rank=accelerator.rank
            )
            dl_val = torch.utils.data.DataLoader(
                ds_val,
                batch_size=self.train_batch_size_per_device,
                shuffle=False,
                num_workers=self.val_data_worker_per_gpu,
                drop_last=True,
                sampler=sampler,
            )
            return dl_val

    dataloader_cfg: DataloaderCfg = field(default_factory=DataloaderCfg)

    @dataclass
    class OptimizerCfg:
        lr_per_img: float = 1.0 / 64.0  # single image learning rate

        def build_optimizer(self, module, dataloader, accelerator):
            """Build Adadelta with the per-image learning rate scaled by the global batch size."""
            optimizer = optim.Adadelta(
                module.parameters(),
                lr=self.lr_per_img * dataloader.batch_size * accelerator.world_size,
            )
            return optimizer

    optimizer_cfg: OptimizerCfg = field(default_factory=OptimizerCfg)

    @dataclass
    class ModuleCfg:
        def build_module(self):
            return Net()

    module_cfg: ModuleCfg = field(default_factory=ModuleCfg)

    @dataclass
    class LrSchedulerCfg:
        def build_lr_scheduler(self, optimizer):
            return StepLR(optimizer, step_size=1, gamma=0.7)

    lr_scheduler_cfg: LrSchedulerCfg = field(default_factory=LrSchedulerCfg)

    # ------------------------------ the following is the execution part --------------------- #
    def run(self) -> None:
        """Build the accelerator and logger, log the configuration (without the hydra section), then train."""
        accelerator = self.accelerator_cfg.build_accelerator()
        logger = self.logger_cfg.build_logger(
            save_dir=os.path.join(self.output_root, self.__class__.__name__),
            distributed_rank=accelerator.rank,
        )
        cfg_dict = OmegaConf.to_container(OmegaConf.structured(self), resolve=True)
        del cfg_dict["hydra"]
        logger.info(f"-------- Configurations --------\n{OmegaConf.to_yaml(cfg_dict)}")

        self._train(accelerator=accelerator, logger=logger, cfg_dict=cfg_dict)

    @staticmethod
    def _start_epoch_iterator(dataloader, epoch):
        """Tell the sampler which epoch starts (when it supports ``set_epoch``) and return a fresh iterator."""
        set_epoch = getattr(getattr(dataloader, "sampler", None), "set_epoch", None)
        if callable(set_epoch):
            # Without this call DistributedSampler replays the same shuffle order every epoch.
            set_epoch(epoch)
        return iter(dataloader)

    def _evaluate(self, accelerator, logger, module_or_module_path, val_dataloader=None) -> None:
        """
        Compute and log the accuracy over the samples the validation dataloader actually yields.

        Args:
            accelerator: Accelerator providing ``device``, ``prepare``, ``reduce_sum``,
                ``wait_for_everyone`` and ``is_main_process``.
            logger: Logger used to report the metric.
            module_or_module_path: Either a module to evaluate, or a path to a saved ``Net`` state dict.
            val_dataloader: Validation dataloader; built from ``dataloader_cfg`` when omitted.
        """
        if isinstance(module_or_module_path, str):
            module = Net()
            # NOTE(review): torch.load unpickles the file; only load checkpoints from trusted sources.
            module.load_state_dict(torch.load(module_or_module_path))
            module = accelerator.prepare(module)
        else:
            module = module_or_module_path

        if val_dataloader is None:
            val_dataloader = self.dataloader_cfg.build_val_dataloader(accelerator)

        module.eval()
        accurate = torch.tensor(0.0, device=accelerator.device)
        evaluated = torch.tensor(0.0, device=accelerator.device)

        for batch in val_dataloader:
            features, labels = (item.to(accelerator.device) for item in batch)
            with torch.no_grad():
                preds = module(features)
            predictions = preds.argmax(dim=-1)
            accurate_preds = predictions == labels
            accurate += accelerator.reduce_sum(accurate_preds.sum())
            # Count what was really evaluated: with drop_last=True the loader yields fewer
            # samples than len(dataset), so dividing by the dataset size under-reports accuracy.
            batch_count = torch.tensor(labels.numel(), device=accelerator.device)
            evaluated += accelerator.reduce_sum(batch_count)

        num_evaluated = evaluated.item()
        eval_metric = accurate.item() / num_evaluated if num_evaluated > 0 else 0.0

        accelerator.wait_for_everyone()
        nowtime = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        logger.info(f"{nowtime} --> eval_metric= {100 * eval_metric:.2f}%")

        if self.wandb_cfg.enable_wandb and accelerator.is_main_process:
            wandb.log({"val_metric": eval_metric})

    def _train(self, accelerator, logger, cfg_dict) -> None:
        """
        Train for 3 epochs, evaluating and stepping the LR scheduler after each epoch.

        Args:
            accelerator: Accelerator used to prepare the module/optimizer and run backward.
            logger: Logger receiving the periodic loss / learning-rate lines.
            cfg_dict: Plain-dict configuration forwarded to wandb as the run config.
        """
        train_dataloader = self.dataloader_cfg.build_train_dataloader(accelerator)
        val_dataloader = self.dataloader_cfg.build_val_dataloader(accelerator)
        ori_module = self.module_cfg.build_module()
        ori_optimizer = self.optimizer_cfg.build_optimizer(ori_module, train_dataloader, accelerator)
        lr_scheduler = self.lr_scheduler_cfg.build_lr_scheduler(ori_optimizer)

        module, optimizer = accelerator.prepare(ori_module, ori_optimizer)
        accelerator.print(f"device {accelerator.device!s} is used!")

        train_iter = self._start_epoch_iterator(train_dataloader, 0)
        if self.wandb_cfg.enable_wandb and accelerator.rank == 0:
            self.wandb_cfg.build_wandb(
                accelerator=accelerator,
                project="Baselines",
                config=cfg_dict,
            )

        # Built once instead of on every step. The module is called without a target, so it
        # returns log-softmax scores; cross-entropy on those equals the NLL loss.
        criterion = nn.CrossEntropyLoss()

        for epoch in range(3):
            module.train()

            for step in range(len(train_dataloader)):
                try:
                    batch = next(train_iter)
                except StopIteration:
                    # Previous pass is exhausted: reshuffle for the current epoch and restart.
                    train_iter = self._start_epoch_iterator(train_dataloader, epoch)
                    batch = next(train_iter)

                features, labels = (item.to(accelerator.device) for item in batch)
                preds = module(features)
                loss = criterion(preds, labels)

                optimizer.zero_grad()
                accelerator.backward(loss)
                optimizer.step()
                if (step + 1) % 20 == 0:
                    logger.info(f"epoch {epoch} loss: {loss.item(): .4f} lr: {optimizer.param_groups[0]['lr']: .4f}")
                    if self.wandb_cfg.enable_wandb and accelerator.rank == 0:
                        wandb.log(
                            {
                                "epoch": epoch,
                                "loss": loss.item(),
                                "lr": optimizer.param_groups[0]["lr"],
                            }
                        )
            self._evaluate(
                accelerator=accelerator, logger=logger, module_or_module_path=module, val_dataloader=val_dataloader
            )

            lr_scheduler.step()
|
|
237
|
+
|
|
238
|
+
|
|
239
|
+
# import hydra
|
|
240
|
+
# @hydra.main(version_base=None, config_name="cfg")
|
|
241
|
+
# def simple_launch_exp(cfg: DictConfig) -> None:
|
|
242
|
+
# from omegaconf import DictConfig, OmegaConf
|
|
243
|
+
# print(OmegaConf.to_yaml(cfg))
|
|
244
|
+
# from IPython import embed; embed() # for debugging
|
|
245
|
+
# exp_class = hydra.utils.get_class(cfg.exp_class)
|
|
246
|
+
# exp_class().set_cfg(cfg).run()
|
|
247
|
+
|
|
248
|
+
# Script entry point: store Exp in hydra's ConfigStore and launch it.
if __name__ == "__main__":
    store_and_run_exp(Exp)
|