plato-learn 1.1 (plato_learn-1.1-py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- plato/__init__.py +1 -0
- plato/algorithms/__init__.py +0 -0
- plato/algorithms/base.py +45 -0
- plato/algorithms/fedavg.py +48 -0
- plato/algorithms/fedavg_gan.py +79 -0
- plato/algorithms/fedavg_personalized.py +48 -0
- plato/algorithms/mistnet.py +52 -0
- plato/algorithms/registry.py +39 -0
- plato/algorithms/split_learning.py +89 -0
- plato/callbacks/__init__.py +0 -0
- plato/callbacks/client.py +56 -0
- plato/callbacks/handler.py +78 -0
- plato/callbacks/server.py +139 -0
- plato/callbacks/trainer.py +124 -0
- plato/client.py +67 -0
- plato/clients/__init__.py +0 -0
- plato/clients/base.py +467 -0
- plato/clients/edge.py +103 -0
- plato/clients/fedavg_personalized.py +40 -0
- plato/clients/mistnet.py +49 -0
- plato/clients/registry.py +43 -0
- plato/clients/self_supervised_learning.py +51 -0
- plato/clients/simple.py +218 -0
- plato/clients/split_learning.py +150 -0
- plato/config.py +339 -0
- plato/datasources/__init__.py +0 -0
- plato/datasources/base.py +123 -0
- plato/datasources/celeba.py +150 -0
- plato/datasources/cifar10.py +87 -0
- plato/datasources/cifar100.py +61 -0
- plato/datasources/cinic10.py +62 -0
- plato/datasources/coco.py +119 -0
- plato/datasources/datalib/__init__.py +0 -0
- plato/datasources/datalib/audio_extraction_tools.py +137 -0
- plato/datasources/datalib/data_utils.py +124 -0
- plato/datasources/datalib/flickr30kE_utils.py +336 -0
- plato/datasources/datalib/frames_extraction_tools.py +254 -0
- plato/datasources/datalib/gym_utils/__init__.py +0 -0
- plato/datasources/datalib/gym_utils/gym_trim.py +189 -0
- plato/datasources/datalib/modality_data_anntation_tools.py +163 -0
- plato/datasources/datalib/modality_extraction_base.py +59 -0
- plato/datasources/datalib/parse_datasets.py +212 -0
- plato/datasources/datalib/refer_utils/__init__.py +0 -0
- plato/datasources/datalib/refer_utils/referitgame_utils.py +237 -0
- plato/datasources/datalib/tiny_data_tools.py +81 -0
- plato/datasources/datalib/video_transform.py +79 -0
- plato/datasources/emnist.py +64 -0
- plato/datasources/fashion_mnist.py +41 -0
- plato/datasources/feature.py +24 -0
- plato/datasources/feature_dataset.py +15 -0
- plato/datasources/femnist.py +141 -0
- plato/datasources/flickr30k_entities.py +362 -0
- plato/datasources/gym.py +431 -0
- plato/datasources/huggingface.py +165 -0
- plato/datasources/kinetics.py +568 -0
- plato/datasources/mnist.py +44 -0
- plato/datasources/multimodal_base.py +328 -0
- plato/datasources/pascal_voc.py +56 -0
- plato/datasources/purchase.py +94 -0
- plato/datasources/qoenflx.py +127 -0
- plato/datasources/referitgame.py +330 -0
- plato/datasources/registry.py +119 -0
- plato/datasources/self_supervised_learning.py +98 -0
- plato/datasources/stl10.py +103 -0
- plato/datasources/texas.py +94 -0
- plato/datasources/tiny_imagenet.py +64 -0
- plato/datasources/yolov8.py +85 -0
- plato/models/__init__.py +0 -0
- plato/models/cnn_encoder.py +103 -0
- plato/models/dcgan.py +116 -0
- plato/models/general_multilayer.py +254 -0
- plato/models/huggingface.py +27 -0
- plato/models/lenet5.py +113 -0
- plato/models/multilayer.py +90 -0
- plato/models/multimodal/__init__.py +0 -0
- plato/models/multimodal/base_net.py +91 -0
- plato/models/multimodal/blending.py +142 -0
- plato/models/multimodal/fc_net.py +77 -0
- plato/models/multimodal/fusion_net.py +78 -0
- plato/models/multimodal/multimodal_module.py +152 -0
- plato/models/registry.py +99 -0
- plato/models/resnet.py +190 -0
- plato/models/torch_hub.py +19 -0
- plato/models/vgg.py +113 -0
- plato/models/vit.py +166 -0
- plato/models/yolov8.py +22 -0
- plato/processors/__init__.py +0 -0
- plato/processors/base.py +35 -0
- plato/processors/compress.py +46 -0
- plato/processors/decompress.py +48 -0
- plato/processors/feature.py +51 -0
- plato/processors/feature_additive_noise.py +48 -0
- plato/processors/feature_dequantize.py +34 -0
- plato/processors/feature_gaussian.py +17 -0
- plato/processors/feature_laplace.py +15 -0
- plato/processors/feature_quantize.py +34 -0
- plato/processors/feature_randomized_response.py +50 -0
- plato/processors/feature_unbatch.py +39 -0
- plato/processors/inbound_feature_tensors.py +39 -0
- plato/processors/model.py +55 -0
- plato/processors/model_compress.py +34 -0
- plato/processors/model_decompress.py +37 -0
- plato/processors/model_decrypt.py +41 -0
- plato/processors/model_deepcopy.py +21 -0
- plato/processors/model_dequantize.py +18 -0
- plato/processors/model_dequantize_qsgd.py +61 -0
- plato/processors/model_encrypt.py +43 -0
- plato/processors/model_quantize.py +18 -0
- plato/processors/model_quantize_qsgd.py +82 -0
- plato/processors/model_randomized_response.py +34 -0
- plato/processors/outbound_feature_ndarrays.py +38 -0
- plato/processors/pipeline.py +26 -0
- plato/processors/registry.py +124 -0
- plato/processors/structured_pruning.py +57 -0
- plato/processors/unstructured_pruning.py +73 -0
- plato/samplers/__init__.py +0 -0
- plato/samplers/all_inclusive.py +41 -0
- plato/samplers/base.py +31 -0
- plato/samplers/dirichlet.py +81 -0
- plato/samplers/distribution_noniid.py +132 -0
- plato/samplers/iid.py +53 -0
- plato/samplers/label_quantity_noniid.py +119 -0
- plato/samplers/mixed.py +44 -0
- plato/samplers/mixed_label_quantity_noniid.py +128 -0
- plato/samplers/modality_iid.py +42 -0
- plato/samplers/modality_quantity_noniid.py +56 -0
- plato/samplers/orthogonal.py +99 -0
- plato/samplers/registry.py +66 -0
- plato/samplers/sample_quantity_noniid.py +123 -0
- plato/samplers/sampler_utils.py +190 -0
- plato/servers/__init__.py +0 -0
- plato/servers/base.py +1395 -0
- plato/servers/fedavg.py +281 -0
- plato/servers/fedavg_cs.py +335 -0
- plato/servers/fedavg_gan.py +74 -0
- plato/servers/fedavg_he.py +106 -0
- plato/servers/fedavg_personalized.py +57 -0
- plato/servers/mistnet.py +67 -0
- plato/servers/registry.py +52 -0
- plato/servers/split_learning.py +109 -0
- plato/trainers/__init__.py +0 -0
- plato/trainers/base.py +99 -0
- plato/trainers/basic.py +649 -0
- plato/trainers/diff_privacy.py +178 -0
- plato/trainers/gan.py +330 -0
- plato/trainers/huggingface.py +173 -0
- plato/trainers/loss_criterion.py +70 -0
- plato/trainers/lr_schedulers.py +252 -0
- plato/trainers/optimizers.py +53 -0
- plato/trainers/pascal_voc.py +80 -0
- plato/trainers/registry.py +44 -0
- plato/trainers/self_supervised_learning.py +302 -0
- plato/trainers/split_learning.py +305 -0
- plato/trainers/tracking.py +96 -0
- plato/trainers/yolov8.py +41 -0
- plato/utils/__init__.py +0 -0
- plato/utils/count_parameters.py +30 -0
- plato/utils/csv_processor.py +26 -0
- plato/utils/data_loaders.py +148 -0
- plato/utils/decorators.py +24 -0
- plato/utils/fonts.py +23 -0
- plato/utils/homo_enc.py +187 -0
- plato/utils/reinforcement_learning/__init__.py +0 -0
- plato/utils/reinforcement_learning/policies/__init__.py +0 -0
- plato/utils/reinforcement_learning/policies/base.py +161 -0
- plato/utils/reinforcement_learning/policies/ddpg.py +75 -0
- plato/utils/reinforcement_learning/policies/registry.py +32 -0
- plato/utils/reinforcement_learning/policies/sac.py +343 -0
- plato/utils/reinforcement_learning/policies/td3.py +485 -0
- plato/utils/reinforcement_learning/rl_agent.py +142 -0
- plato/utils/reinforcement_learning/rl_server.py +113 -0
- plato/utils/rl_env.py +154 -0
- plato/utils/s3.py +141 -0
- plato/utils/trainer_utils.py +21 -0
- plato/utils/unary_encoding.py +47 -0
- plato_learn-1.1.dist-info/METADATA +35 -0
- plato_learn-1.1.dist-info/RECORD +179 -0
- plato_learn-1.1.dist-info/WHEEL +4 -0
- plato_learn-1.1.dist-info/licenses/LICENSE +201 -0
plato/trainers/loss_criterion.py
@@ -0,0 +1,70 @@
```python
"""
Obtaining the loss criterion for training workloads according to the configuration file.
"""

from typing import Union

from torch import nn
from lightly import loss

from plato.config import Config


def get(**kwargs: Union[str, dict]):
    """Get a loss function with its name from the configuration file."""
    registered_loss_criterion = {
        "L1Loss": nn.L1Loss,
        "MSELoss": nn.MSELoss,
        "BCELoss": nn.BCELoss,
        "BCEWithLogitsLoss": nn.BCEWithLogitsLoss,
        "NLLLoss": nn.NLLLoss,
        "PoissonNLLLoss": nn.PoissonNLLLoss,
        "CrossEntropyLoss": nn.CrossEntropyLoss,
        "HingeEmbeddingLoss": nn.HingeEmbeddingLoss,
        "MarginRankingLoss": nn.MarginRankingLoss,
        "TripletMarginLoss": nn.TripletMarginLoss,
        "KLDivLoss": nn.KLDivLoss,
    }

    ssl_loss_criterion = {
        "NegativeCosineSimilarity": loss.NegativeCosineSimilarity,
        "NTXentLoss": loss.NTXentLoss,
        "BarlowTwinsLoss": loss.BarlowTwinsLoss,
        "DCLLoss": loss.DCLLoss,
        "DCLWLoss": loss.DCLWLoss,
        "DINOLoss": loss.DINOLoss,
        "PMSNCustomLoss": loss.PMSNCustomLoss,
        "SwaVLoss": loss.SwaVLoss,
        "PMSNLoss": loss.PMSNLoss,
        "SymNegCosineSimilarityLoss": loss.SymNegCosineSimilarityLoss,
        "TiCoLoss": loss.TiCoLoss,
        "VICRegLoss": loss.VICRegLoss,
        "VICRegLLoss": loss.VICRegLLoss,
        "MSNLoss": loss.MSNLoss,
    }

    registered_loss_criterion.update(ssl_loss_criterion)

    loss_criterion_name = (
        kwargs["loss_criterion"]
        if "loss_criterion" in kwargs
        else (
            Config().trainer.loss_criterion
            if hasattr(Config.trainer, "loss_criterion")
            else "CrossEntropyLoss"
        )
    )

    loss_criterion_params = (
        kwargs["loss_criterion_params"]
        if "loss_criterion_params" in kwargs
        else (
            Config().parameters.loss_criterion._asdict()
            if hasattr(Config.parameters, "loss_criterion")
            else {}
        )
    )

    loss_criterion = registered_loss_criterion.get(loss_criterion_name)

    return loss_criterion(**loss_criterion_params)
```
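For reference, a minimal usage sketch of `loss_criterion.get` (not part of the package): when both keyword arguments are supplied, the function never touches the global `Config`, so it can be exercised standalone assuming `torch` and `lightly` are installed.

```python
# Hypothetical usage sketch -- not part of the package.
# Passing both kwargs bypasses the Config() lookups entirely.
from plato.trainers import loss_criterion

criterion = loss_criterion.get(
    loss_criterion="CrossEntropyLoss",  # key into registered_loss_criterion
    loss_criterion_params={"label_smoothing": 0.1},  # forwarded to nn.CrossEntropyLoss
)
```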
plato/trainers/lr_schedulers.py
@@ -0,0 +1,252 @@
```python
"""
Returns a learning rate scheduler according to the configuration.
"""

import bisect
import sys
from types import SimpleNamespace
from typing import Union

import numpy as np
from timm import scheduler
from torch import optim

from plato.config import Config


def get(
    optimizer: optim.Optimizer, iterations_per_epoch: int, **kwargs: Union[str, dict]
):
    """Returns a learning rate scheduler according to the configuration."""

    registered_schedulers = {
        "CosineAnnealingLR": optim.lr_scheduler.CosineAnnealingLR,
        "LambdaLR": optim.lr_scheduler.LambdaLR,
        "MultiStepLR": optim.lr_scheduler.MultiStepLR,
        "StepLR": optim.lr_scheduler.StepLR,
        "ReduceLROnPlateau": optim.lr_scheduler.ReduceLROnPlateau,
        "ConstantLR": optim.lr_scheduler.ConstantLR,
        "LinearLR": optim.lr_scheduler.LinearLR,
        "ExponentialLR": optim.lr_scheduler.ExponentialLR,
        "CyclicLR": optim.lr_scheduler.CyclicLR,
        "CosineAnnealingWarmRestarts": optim.lr_scheduler.CosineAnnealingWarmRestarts,
    }

    registered_factories = {
        "timm": scheduler.create_scheduler,
    }

    _scheduler = (
        kwargs["lr_scheduler"]
        if "lr_scheduler" in kwargs
        else Config().trainer.lr_scheduler
    )
    lr_params = (
        kwargs["lr_params"]
        if "lr_params" in kwargs
        else Config().parameters.learning_rate._asdict()
    )

    # First, look up the registered factories of LR schedulers
    if _scheduler in registered_factories:
        scheduler_args = SimpleNamespace(**lr_params)
        scheduler_args.epochs = Config().trainer.epochs
        lr_scheduler, __ = registered_factories[_scheduler](
            args=scheduler_args, optimizer=optimizer
        )
        return lr_scheduler

    # The list containing the learning rate schedulers that must be returned or
    # the learning rate schedulers that ChainedScheduler or SequentialLR will
    # take as an argument.
    returned_schedulers = []

    use_chained = False
    use_sequential = False
    if "ChainedScheduler" in _scheduler:
        use_chained = True
        lr_scheduler = [
            sched for sched in _scheduler.split(",") if sched != ("ChainedScheduler")
        ]
    elif "SequentialLR" in _scheduler:
        use_sequential = True
        lr_scheduler = [
            sched for sched in _scheduler.split(",") if sched != ("SequentialLR")
        ]
    else:
        lr_scheduler = [_scheduler]

    for _scheduler in lr_scheduler:
        retrieved_scheduler = registered_schedulers.get(_scheduler)

        if retrieved_scheduler is None:
            sys.exit("Error: Unknown learning rate scheduler.")

        if _scheduler == "CosineAnnealingLR":
            returned_schedulers.append(
                retrieved_scheduler(
                    optimizer, iterations_per_epoch * Config().trainer.epochs
                )
            )
        elif _scheduler == "LambdaLR":
            lambdas = [lambda it: 1.0]

            if "gamma" in lr_params and "milestone_steps" in lr_params:
                milestones = [
                    Step.from_str(x, iterations_per_epoch).iteration
                    for x in lr_params["milestone_steps"].split(",")
                ]
                lambdas.append(
                    lambda it, milestones=milestones: lr_params["gamma"]
                    ** bisect.bisect(milestones, it)
                )

            # Add a linear learning rate warmup if specified
            if "warmup_steps" in lr_params:
                warmup_iters = Step.from_str(
                    lr_params["warmup_steps"], iterations_per_epoch
                ).iteration
                lambdas.append(
                    lambda it, warmup_iters=warmup_iters: min(1.0, it / warmup_iters)
                )
            returned_schedulers.append(
                retrieved_scheduler(
                    optimizer,
                    lambda it, lambdas=lambdas: np.prod([l(it) for l in lambdas]),
                )
            )
        elif _scheduler == "MultiStepLR":
            milestones = [
                int(x.split("ep")[0]) for x in lr_params["milestone_steps"].split(",")
            ]
            returned_schedulers.append(
                retrieved_scheduler(
                    optimizer, milestones=milestones, gamma=lr_params["gamma"]
                )
            )
        else:
            returned_schedulers.append(retrieved_scheduler(optimizer, **lr_params))

    if use_chained:
        return optim.lr_scheduler.ChainedScheduler(returned_schedulers)

    if use_sequential:
        sequential_milestones = (
            Config().trainer.lr_sequential_milestones
            if hasattr(Config().trainer, "lr_sequential_milestones")
            else 2
        )
        sequential_milestones = [
            int(epoch) for epoch in sequential_milestones.split(",")
        ]

        return optim.lr_scheduler.SequentialLR(
            optimizer, returned_schedulers, sequential_milestones
        )
    else:
        return returned_schedulers[0]


class Step:
    """Represents a particular step of training."""

    def __init__(self, iteration: int, iterations_per_epoch: int) -> None:
        if iteration < 0:
            raise ValueError("iteration must >= 0.")
        if iterations_per_epoch <= 0:
            raise ValueError("iterations_per_epoch must be > 0.")
        self._iteration = iteration
        self.iterations_per_epoch = iterations_per_epoch

    @staticmethod
    def str_is_zero(s: str) -> bool:
        return s in ["0ep", "0it", "0ep0it"]

    @staticmethod
    def from_iteration(iteration: int, iterations_per_epoch: int) -> "Step":
        return Step(iteration, iterations_per_epoch)

    @staticmethod
    def from_epoch(epoch: int, iteration: int, iterations_per_epoch: int) -> "Step":
        return Step(epoch * iterations_per_epoch + iteration, iterations_per_epoch)

    @staticmethod
    def from_str(s: str, iterations_per_epoch: int) -> "Step":
        """Creates a step from a string that describes the number of epochs, iterations, or both.

        Epochs: '120ep'
        Iterations: '2000it'
        Both: '120ep50it'"""

        if "ep" in s and "it" in s:
            ep = int(s.split("ep")[0])
            it = int(s.split("ep")[1].split("it")[0])
            if s != "{}ep{}it".format(ep, it):
                raise ValueError(f"Malformed string step: {s}")
            return Step.from_epoch(ep, it, iterations_per_epoch)
        elif "ep" in s:
            ep = int(s.split("ep")[0])
            if s != "{}ep".format(ep):
                raise ValueError(f"Malformed string step: {s}")
            return Step.from_epoch(ep, 0, iterations_per_epoch)
        elif "it" in s:
            it = int(s.split("it")[0])
            if s != "{}it".format(it):
                raise ValueError(f"Malformed string step: {s}")
            return Step.from_iteration(it, iterations_per_epoch)
        else:
            raise ValueError(f"Malformed string step: {s}")

    @staticmethod
    def zero(iterations_per_epoch: int) -> "Step":
        return Step(0, iterations_per_epoch)

    @property
    def iteration(self):
        """The overall number of steps of training completed so far."""
        return self._iteration

    @property
    def ep(self):
        """The current epoch of training."""
        return self._iteration // self.iterations_per_epoch

    @property
    def it(self):
        """The iteration within the current epoch of training."""
        return self._iteration % self.iterations_per_epoch

    def _check(self, other):
        if not isinstance(other, Step):
            raise ValueError(f"Invalid type for other: {other}.")
        if self.iterations_per_epoch != other.iterations_per_epoch:
            raise ValueError(
                "Cannot compare steps when epochs are of different lengths."
            )

    def __lt__(self, other):
        self._check(other)
        return self._iteration < other.iteration

    def __le__(self, other):
        self._check(other)
        return self._iteration <= other.iteration

    def __eq__(self, other):
        self._check(other)
        return self._iteration == other.iteration

    def __ne__(self, other):
        self._check(other)
        return self._iteration != other.iteration

    def __gt__(self, other):
        self._check(other)
        return self._iteration > other.iteration

    def __ge__(self, other):
        self._check(other)
        return self._iteration >= other.iteration

    def __str__(self):
        return f"(Iteration {self._iteration}; Iterations per Epoch: {self.iterations_per_epoch})"
```
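A minimal sketch of the two entry points above (not part of the package), assuming `torch`, `timm`, and `numpy` are installed. Note that, as written, the integer fallback `2` for `lr_sequential_milestones` would fail at `.split(",")`, so configurations using `SequentialLR` should set that key explicitly.

```python
# Hypothetical usage sketch -- not part of the package.
from torch import nn, optim

from plato.trainers.lr_schedulers import Step, get

model = nn.Linear(10, 2)
optimizer = optim.SGD(model.parameters(), lr=0.1)

# Passing both kwargs bypasses the Config() lookups.
lr_scheduler = get(
    optimizer,
    iterations_per_epoch=100,
    lr_scheduler="StepLR",
    lr_params={"step_size": 30, "gamma": 0.1},
)

# Step parses "<ep>ep<it>it" strings into absolute iteration counts:
# 2 epochs x 100 iterations + 50 iterations = 250.
assert Step.from_str("2ep50it", iterations_per_epoch=100).iteration == 250
```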
plato/trainers/optimizers.py
@@ -0,0 +1,53 @@
```python
"""
Optimizers for training workloads.
"""

from typing import Union

import torch_optimizer as torch_optim
from torch import optim
from timm import optim as timm_optim

from plato.config import Config


def get(model, **kwargs: Union[str, dict]) -> optim.Optimizer:
    """Get an optimizer with its name and parameters obtained from the configuration file."""
    registered_optimizers = {
        "Adam": optim.Adam,
        "Adadelta": optim.Adadelta,
        "Adagrad": optim.Adagrad,
        "AdaHessian": torch_optim.Adahessian,
        "AdamW": optim.AdamW,
        "SparseAdam": optim.SparseAdam,
        "Adamax": optim.Adamax,
        "ASGD": optim.ASGD,
        "LBFGS": optim.LBFGS,
        "NAdam": optim.NAdam,
        "RAdam": optim.RAdam,
        "RMSprop": optim.RMSprop,
        "Rprop": optim.Rprop,
        "SGD": optim.SGD,
        "LARS": timm_optim.lars.Lars,
    }

    optimizer_name = (
        kwargs["optimizer_name"]
        if "optimizer_name" in kwargs
        else Config().trainer.optimizer
    )
    optimizer_params = (
        kwargs["optimizer_params"]
        if "optimizer_params" in kwargs
        else Config().parameters.optimizer._asdict()
    )

    # Ensure eps is a float
    if "eps" in optimizer_params:
        optimizer_params["eps"] = float(optimizer_params["eps"])

    optimizer = registered_optimizers.get(optimizer_name)
    if optimizer is not None:
        return optimizer(model.parameters(), **optimizer_params)

    raise ValueError(f"No such optimizer: {optimizer_name}")
```
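As with the loss criterion registry, a minimal usage sketch (not part of the package): supplying both kwargs skips the configuration lookups.

```python
# Hypothetical usage sketch -- not part of the package.
from torch import nn

from plato.trainers import optimizers

model = nn.Linear(10, 2)

# Passing both kwargs bypasses the Config() lookups.
sgd = optimizers.get(
    model,
    optimizer_name="SGD",
    optimizer_params={"lr": 0.01, "momentum": 0.9},
)
```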
plato/trainers/pascal_voc.py
@@ -0,0 +1,80 @@
```python
"""
A customized trainer for image segmentation on PASCAL VOC dataset (2012).
"""

import torch.nn as nn
import torch
import numpy as np

from plato.trainers import basic


class Evaluator(object):
    def __init__(self, num_class):
        self.num_class = num_class
        self.confusion_matrix = np.zeros((self.num_class,) * 2)

    def Mean_Intersection_over_Union(self):
        MIoU = np.diag(self.confusion_matrix) / (
            np.sum(self.confusion_matrix, axis=1)
            + np.sum(self.confusion_matrix, axis=0)
            - np.diag(self.confusion_matrix)
        )
        MIoU = np.nanmean(MIoU)
        return MIoU

    def _generate_matrix(self, gt_image, pre_image):
        mask = (gt_image >= 0) & (gt_image < self.num_class)
        label = self.num_class * gt_image[mask].astype("int") + pre_image[mask]
        count = np.bincount(label, minlength=self.num_class**2)
        confusion_matrix = count.reshape(self.num_class, self.num_class)
        return confusion_matrix

    def add_batch(self, gt_image, pre_image):
        assert gt_image.shape == pre_image.shape
        self.confusion_matrix += self._generate_matrix(gt_image, pre_image)

    def reset(self):
        self.confusion_matrix = np.zeros((self.num_class,) * 2)


class Trainer(basic.Trainer):
    """The federated learning trainer for the image segmentation on PASCAL VOC"""

    def __init__(self, model=None, **kwargs):
        """Initializing the trainer with the provided model.

        Arguments:
        model: The model to train.
        client_id: The ID of the client using this trainer (optional).
        """
        super().__init__(model)

        self.loss_criterion = nn.BCEWithLogitsLoss()
        self.num_class = 20

    def test_model(self, config, testset, sampler=None, **kwargs):
        test_loader = torch.utils.data.DataLoader(
            testset, batch_size=config["batch_size"], shuffle=False
        )

        total = 0
        evaluator = Evaluator(self.num_class)
        evaluator.reset()
        with torch.no_grad():
            for examples, labels in test_loader:
                examples, labels = examples.to(self.device), labels.to(self.device)

                outputs = self.model(examples)

                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                labels = torch.squeeze(labels, 1).cpu().numpy()
                predicted = predicted.cpu().numpy()
                print("shape of pred: ", predicted.shape)
                print("shape of labels: ", labels.shape)
                evaluator.add_batch(labels, predicted)

        accuracy = evaluator.Mean_Intersection_over_Union()

        return accuracy
```
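The mean IoU is computed from a running confusion matrix whose rows are ground-truth classes and columns are predictions. A small worked example (not from the package) of what `Evaluator` accumulates:

```python
# Hypothetical worked example -- not part of the package.
import numpy as np

from plato.trainers.pascal_voc import Evaluator

evaluator = Evaluator(num_class=2)
gt = np.array([[0, 0, 1, 1]])  # ground-truth class indices
pred = np.array([[0, 1, 1, 1]])  # predicted class indices

evaluator.add_batch(gt, pred)
# Confusion matrix: [[1, 1], [0, 2]]
# IoU(class 0) = 1 / (2 + 1 - 1) = 0.5
# IoU(class 1) = 2 / (2 + 3 - 2) = 2/3
print(evaluator.Mean_Intersection_over_Union())  # ~0.5833
```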
plato/trainers/registry.py
@@ -0,0 +1,44 @@
```python
"""
Having a registry of all available classes is convenient for retrieving an instance
based on a configuration at run-time.
"""

import logging

from plato.config import Config


from plato.trainers import basic, diff_privacy, pascal_voc, gan, split_learning

registered_trainers = {
    "basic": basic.Trainer,
    "timm_basic": basic.TrainerWithTimmScheduler,
    "diff_privacy": diff_privacy.Trainer,
    "pascal_voc": pascal_voc.Trainer,
    "gan": gan.Trainer,
    "split_learning": split_learning.Trainer,
}


def get(model=None, callbacks=None):
    """Get the trainer with the provided name."""
    trainer_name = Config().trainer.type
    logging.info("Trainer: %s", trainer_name)

    if Config().trainer.model_name == "yolov8":
        from plato.trainers import yolov8

        return yolov8.Trainer()
    elif Config().trainer.type == "HuggingFace":
        from plato.trainers import huggingface

        return huggingface.Trainer(model=model, callbacks=callbacks)

    elif Config().trainer.type == "self_supervised_learning":
        from plato.trainers import self_supervised_learning

        return self_supervised_learning.Trainer(model=model, callbacks=callbacks)
    elif trainer_name in registered_trainers:
        return registered_trainers[trainer_name](model=model, callbacks=callbacks)
    else:
        raise ValueError(f"No such trainer: {trainer_name}")
```
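For reference (not part of the package), the dispatch above is driven by `Config().trainer.type` and `Config().trainer.model_name`; a sketch assuming a Plato configuration file has already been loaded:

```python
# Hypothetical usage sketch -- not part of the package.
# Assumes a Plato configuration whose trainer section sets `type`
# (e.g. "basic") and `model_name`; without one, get() will fail.
from plato.trainers import registry as trainers_registry

# The statically registered trainers; yolov8, HuggingFace, and
# self_supervised_learning are special-cased with deferred imports above.
print(sorted(trainers_registry.registered_trainers))

trainer = trainers_registry.get()  # dispatches on Config().trainer.type
```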