plato-learn 1.1 (plato_learn-1.1-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (179)
  1. plato/__init__.py +1 -0
  2. plato/algorithms/__init__.py +0 -0
  3. plato/algorithms/base.py +45 -0
  4. plato/algorithms/fedavg.py +48 -0
  5. plato/algorithms/fedavg_gan.py +79 -0
  6. plato/algorithms/fedavg_personalized.py +48 -0
  7. plato/algorithms/mistnet.py +52 -0
  8. plato/algorithms/registry.py +39 -0
  9. plato/algorithms/split_learning.py +89 -0
  10. plato/callbacks/__init__.py +0 -0
  11. plato/callbacks/client.py +56 -0
  12. plato/callbacks/handler.py +78 -0
  13. plato/callbacks/server.py +139 -0
  14. plato/callbacks/trainer.py +124 -0
  15. plato/client.py +67 -0
  16. plato/clients/__init__.py +0 -0
  17. plato/clients/base.py +467 -0
  18. plato/clients/edge.py +103 -0
  19. plato/clients/fedavg_personalized.py +40 -0
  20. plato/clients/mistnet.py +49 -0
  21. plato/clients/registry.py +43 -0
  22. plato/clients/self_supervised_learning.py +51 -0
  23. plato/clients/simple.py +218 -0
  24. plato/clients/split_learning.py +150 -0
  25. plato/config.py +339 -0
  26. plato/datasources/__init__.py +0 -0
  27. plato/datasources/base.py +123 -0
  28. plato/datasources/celeba.py +150 -0
  29. plato/datasources/cifar10.py +87 -0
  30. plato/datasources/cifar100.py +61 -0
  31. plato/datasources/cinic10.py +62 -0
  32. plato/datasources/coco.py +119 -0
  33. plato/datasources/datalib/__init__.py +0 -0
  34. plato/datasources/datalib/audio_extraction_tools.py +137 -0
  35. plato/datasources/datalib/data_utils.py +124 -0
  36. plato/datasources/datalib/flickr30kE_utils.py +336 -0
  37. plato/datasources/datalib/frames_extraction_tools.py +254 -0
  38. plato/datasources/datalib/gym_utils/__init__.py +0 -0
  39. plato/datasources/datalib/gym_utils/gym_trim.py +189 -0
  40. plato/datasources/datalib/modality_data_anntation_tools.py +163 -0
  41. plato/datasources/datalib/modality_extraction_base.py +59 -0
  42. plato/datasources/datalib/parse_datasets.py +212 -0
  43. plato/datasources/datalib/refer_utils/__init__.py +0 -0
  44. plato/datasources/datalib/refer_utils/referitgame_utils.py +237 -0
  45. plato/datasources/datalib/tiny_data_tools.py +81 -0
  46. plato/datasources/datalib/video_transform.py +79 -0
  47. plato/datasources/emnist.py +64 -0
  48. plato/datasources/fashion_mnist.py +41 -0
  49. plato/datasources/feature.py +24 -0
  50. plato/datasources/feature_dataset.py +15 -0
  51. plato/datasources/femnist.py +141 -0
  52. plato/datasources/flickr30k_entities.py +362 -0
  53. plato/datasources/gym.py +431 -0
  54. plato/datasources/huggingface.py +165 -0
  55. plato/datasources/kinetics.py +568 -0
  56. plato/datasources/mnist.py +44 -0
  57. plato/datasources/multimodal_base.py +328 -0
  58. plato/datasources/pascal_voc.py +56 -0
  59. plato/datasources/purchase.py +94 -0
  60. plato/datasources/qoenflx.py +127 -0
  61. plato/datasources/referitgame.py +330 -0
  62. plato/datasources/registry.py +119 -0
  63. plato/datasources/self_supervised_learning.py +98 -0
  64. plato/datasources/stl10.py +103 -0
  65. plato/datasources/texas.py +94 -0
  66. plato/datasources/tiny_imagenet.py +64 -0
  67. plato/datasources/yolov8.py +85 -0
  68. plato/models/__init__.py +0 -0
  69. plato/models/cnn_encoder.py +103 -0
  70. plato/models/dcgan.py +116 -0
  71. plato/models/general_multilayer.py +254 -0
  72. plato/models/huggingface.py +27 -0
  73. plato/models/lenet5.py +113 -0
  74. plato/models/multilayer.py +90 -0
  75. plato/models/multimodal/__init__.py +0 -0
  76. plato/models/multimodal/base_net.py +91 -0
  77. plato/models/multimodal/blending.py +142 -0
  78. plato/models/multimodal/fc_net.py +77 -0
  79. plato/models/multimodal/fusion_net.py +78 -0
  80. plato/models/multimodal/multimodal_module.py +152 -0
  81. plato/models/registry.py +99 -0
  82. plato/models/resnet.py +190 -0
  83. plato/models/torch_hub.py +19 -0
  84. plato/models/vgg.py +113 -0
  85. plato/models/vit.py +166 -0
  86. plato/models/yolov8.py +22 -0
  87. plato/processors/__init__.py +0 -0
  88. plato/processors/base.py +35 -0
  89. plato/processors/compress.py +46 -0
  90. plato/processors/decompress.py +48 -0
  91. plato/processors/feature.py +51 -0
  92. plato/processors/feature_additive_noise.py +48 -0
  93. plato/processors/feature_dequantize.py +34 -0
  94. plato/processors/feature_gaussian.py +17 -0
  95. plato/processors/feature_laplace.py +15 -0
  96. plato/processors/feature_quantize.py +34 -0
  97. plato/processors/feature_randomized_response.py +50 -0
  98. plato/processors/feature_unbatch.py +39 -0
  99. plato/processors/inbound_feature_tensors.py +39 -0
  100. plato/processors/model.py +55 -0
  101. plato/processors/model_compress.py +34 -0
  102. plato/processors/model_decompress.py +37 -0
  103. plato/processors/model_decrypt.py +41 -0
  104. plato/processors/model_deepcopy.py +21 -0
  105. plato/processors/model_dequantize.py +18 -0
  106. plato/processors/model_dequantize_qsgd.py +61 -0
  107. plato/processors/model_encrypt.py +43 -0
  108. plato/processors/model_quantize.py +18 -0
  109. plato/processors/model_quantize_qsgd.py +82 -0
  110. plato/processors/model_randomized_response.py +34 -0
  111. plato/processors/outbound_feature_ndarrays.py +38 -0
  112. plato/processors/pipeline.py +26 -0
  113. plato/processors/registry.py +124 -0
  114. plato/processors/structured_pruning.py +57 -0
  115. plato/processors/unstructured_pruning.py +73 -0
  116. plato/samplers/__init__.py +0 -0
  117. plato/samplers/all_inclusive.py +41 -0
  118. plato/samplers/base.py +31 -0
  119. plato/samplers/dirichlet.py +81 -0
  120. plato/samplers/distribution_noniid.py +132 -0
  121. plato/samplers/iid.py +53 -0
  122. plato/samplers/label_quantity_noniid.py +119 -0
  123. plato/samplers/mixed.py +44 -0
  124. plato/samplers/mixed_label_quantity_noniid.py +128 -0
  125. plato/samplers/modality_iid.py +42 -0
  126. plato/samplers/modality_quantity_noniid.py +56 -0
  127. plato/samplers/orthogonal.py +99 -0
  128. plato/samplers/registry.py +66 -0
  129. plato/samplers/sample_quantity_noniid.py +123 -0
  130. plato/samplers/sampler_utils.py +190 -0
  131. plato/servers/__init__.py +0 -0
  132. plato/servers/base.py +1395 -0
  133. plato/servers/fedavg.py +281 -0
  134. plato/servers/fedavg_cs.py +335 -0
  135. plato/servers/fedavg_gan.py +74 -0
  136. plato/servers/fedavg_he.py +106 -0
  137. plato/servers/fedavg_personalized.py +57 -0
  138. plato/servers/mistnet.py +67 -0
  139. plato/servers/registry.py +52 -0
  140. plato/servers/split_learning.py +109 -0
  141. plato/trainers/__init__.py +0 -0
  142. plato/trainers/base.py +99 -0
  143. plato/trainers/basic.py +649 -0
  144. plato/trainers/diff_privacy.py +178 -0
  145. plato/trainers/gan.py +330 -0
  146. plato/trainers/huggingface.py +173 -0
  147. plato/trainers/loss_criterion.py +70 -0
  148. plato/trainers/lr_schedulers.py +252 -0
  149. plato/trainers/optimizers.py +53 -0
  150. plato/trainers/pascal_voc.py +80 -0
  151. plato/trainers/registry.py +44 -0
  152. plato/trainers/self_supervised_learning.py +302 -0
  153. plato/trainers/split_learning.py +305 -0
  154. plato/trainers/tracking.py +96 -0
  155. plato/trainers/yolov8.py +41 -0
  156. plato/utils/__init__.py +0 -0
  157. plato/utils/count_parameters.py +30 -0
  158. plato/utils/csv_processor.py +26 -0
  159. plato/utils/data_loaders.py +148 -0
  160. plato/utils/decorators.py +24 -0
  161. plato/utils/fonts.py +23 -0
  162. plato/utils/homo_enc.py +187 -0
  163. plato/utils/reinforcement_learning/__init__.py +0 -0
  164. plato/utils/reinforcement_learning/policies/__init__.py +0 -0
  165. plato/utils/reinforcement_learning/policies/base.py +161 -0
  166. plato/utils/reinforcement_learning/policies/ddpg.py +75 -0
  167. plato/utils/reinforcement_learning/policies/registry.py +32 -0
  168. plato/utils/reinforcement_learning/policies/sac.py +343 -0
  169. plato/utils/reinforcement_learning/policies/td3.py +485 -0
  170. plato/utils/reinforcement_learning/rl_agent.py +142 -0
  171. plato/utils/reinforcement_learning/rl_server.py +113 -0
  172. plato/utils/rl_env.py +154 -0
  173. plato/utils/s3.py +141 -0
  174. plato/utils/trainer_utils.py +21 -0
  175. plato/utils/unary_encoding.py +47 -0
  176. plato_learn-1.1.dist-info/METADATA +35 -0
  177. plato_learn-1.1.dist-info/RECORD +179 -0
  178. plato_learn-1.1.dist-info/WHEEL +4 -0
  179. plato_learn-1.1.dist-info/licenses/LICENSE +201 -0
plato/trainers/loss_criterion.py
@@ -0,0 +1,70 @@
+ """
+ Obtaining the loss criterion for training workloads according to the configuration file.
+ """
+
+ from typing import Union
+
+ from torch import nn
+ from lightly import loss
+
+ from plato.config import Config
+
+
+ def get(**kwargs: Union[str, dict]):
+     """Get a loss function with its name from the configuration file."""
+     registered_loss_criterion = {
+         "L1Loss": nn.L1Loss,
+         "MSELoss": nn.MSELoss,
+         "BCELoss": nn.BCELoss,
+         "BCEWithLogitsLoss": nn.BCEWithLogitsLoss,
+         "NLLLoss": nn.NLLLoss,
+         "PoissonNLLLoss": nn.PoissonNLLLoss,
+         "CrossEntropyLoss": nn.CrossEntropyLoss,
+         "HingeEmbeddingLoss": nn.HingeEmbeddingLoss,
+         "MarginRankingLoss": nn.MarginRankingLoss,
+         "TripletMarginLoss": nn.TripletMarginLoss,
+         "KLDivLoss": nn.KLDivLoss,
+     }
+
+     ssl_loss_criterion = {
+         "NegativeCosineSimilarity": loss.NegativeCosineSimilarity,
+         "NTXentLoss": loss.NTXentLoss,
+         "BarlowTwinsLoss": loss.BarlowTwinsLoss,
+         "DCLLoss": loss.DCLLoss,
+         "DCLWLoss": loss.DCLWLoss,
+         "DINOLoss": loss.DINOLoss,
+         "PMSNCustomLoss": loss.PMSNCustomLoss,
+         "SwaVLoss": loss.SwaVLoss,
+         "PMSNLoss": loss.PMSNLoss,
+         "SymNegCosineSimilarityLoss": loss.SymNegCosineSimilarityLoss,
+         "TiCoLoss": loss.TiCoLoss,
+         "VICRegLoss": loss.VICRegLoss,
+         "VICRegLLoss": loss.VICRegLLoss,
+         "MSNLoss": loss.MSNLoss,
+     }
+
+     registered_loss_criterion.update(ssl_loss_criterion)
+
+     loss_criterion_name = (
+         kwargs["loss_criterion"]
+         if "loss_criterion" in kwargs
+         else (
+             Config().trainer.loss_criterion
+             if hasattr(Config.trainer, "loss_criterion")
+             else "CrossEntropyLoss"
+         )
+     )
+
+     loss_criterion_params = (
+         kwargs["loss_criterion_params"]
+         if "loss_criterion_params" in kwargs
+         else (
+             Config().parameters.loss_criterion._asdict()
+             if hasattr(Config.parameters, "loss_criterion")
+             else {}
+         )
+     )
+
+     loss_criterion = registered_loss_criterion.get(loss_criterion_name)
+
+     return loss_criterion(**loss_criterion_params)
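For context, here is a minimal usage sketch of this get() factory. The call site and the parameter values are illustrative, not taken from the package; passing the keyword arguments shown bypasses whatever the Plato configuration file would otherwise supply:

    from plato.trainers import loss_criterion

    # Keyword arguments take precedence over Config().trainer.loss_criterion
    # and Config().parameters.loss_criterion.
    criterion = loss_criterion.get(
        loss_criterion="CrossEntropyLoss",
        loss_criterion_params={"label_smoothing": 0.1},
    )
    # criterion is now an nn.CrossEntropyLoss instance ready for training.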
plato/trainers/lr_schedulers.py
@@ -0,0 +1,252 @@
+ """
+ Returns a learning rate scheduler according to the configuration.
+ """
+
+ import bisect
+ import sys
+ from types import SimpleNamespace
+ from typing import Union
+
+ import numpy as np
+ from timm import scheduler
+ from torch import optim
+
+ from plato.config import Config
+
+
+ def get(
+     optimizer: optim.Optimizer, iterations_per_epoch: int, **kwargs: Union[str, dict]
+ ):
+     """Returns a learning rate scheduler according to the configuration."""
+
+     registered_schedulers = {
+         "CosineAnnealingLR": optim.lr_scheduler.CosineAnnealingLR,
+         "LambdaLR": optim.lr_scheduler.LambdaLR,
+         "MultiStepLR": optim.lr_scheduler.MultiStepLR,
+         "StepLR": optim.lr_scheduler.StepLR,
+         "ReduceLROnPlateau": optim.lr_scheduler.ReduceLROnPlateau,
+         "ConstantLR": optim.lr_scheduler.ConstantLR,
+         "LinearLR": optim.lr_scheduler.LinearLR,
+         "ExponentialLR": optim.lr_scheduler.ExponentialLR,
+         "CyclicLR": optim.lr_scheduler.CyclicLR,
+         "CosineAnnealingWarmRestarts": optim.lr_scheduler.CosineAnnealingWarmRestarts,
+     }
+
+     registered_factories = {
+         "timm": scheduler.create_scheduler,
+     }
+
+     _scheduler = (
+         kwargs["lr_scheduler"]
+         if "lr_scheduler" in kwargs
+         else Config().trainer.lr_scheduler
+     )
+     lr_params = (
+         kwargs["lr_params"]
+         if "lr_params" in kwargs
+         else Config().parameters.learning_rate._asdict()
+     )
+
+     # First, look up the registered factories of LR schedulers
+     if _scheduler in registered_factories:
+         scheduler_args = SimpleNamespace(**lr_params)
+         scheduler_args.epochs = Config().trainer.epochs
+         lr_scheduler, __ = registered_factories[_scheduler](
+             args=scheduler_args, optimizer=optimizer
+         )
+         return lr_scheduler
+
+     # The list containing the learning rate schedulers that must be returned or
+     # the learning rate schedulers that ChainedScheduler or SequentialLR will
+     # take as an argument.
+     returned_schedulers = []
+
+     use_chained = False
+     use_sequential = False
+     if "ChainedScheduler" in _scheduler:
+         use_chained = True
+         lr_scheduler = [
+             sched for sched in _scheduler.split(",") if sched != ("ChainedScheduler")
+         ]
+     elif "SequentialLR" in _scheduler:
+         use_sequential = True
+         lr_scheduler = [
+             sched for sched in _scheduler.split(",") if sched != ("SequentialLR")
+         ]
+     else:
+         lr_scheduler = [_scheduler]
+
+     for _scheduler in lr_scheduler:
+         retrieved_scheduler = registered_schedulers.get(_scheduler)
+
+         if retrieved_scheduler is None:
+             sys.exit("Error: Unknown learning rate scheduler.")
+
+         if _scheduler == "CosineAnnealingLR":
+             returned_schedulers.append(
+                 retrieved_scheduler(
+                     optimizer, iterations_per_epoch * Config().trainer.epochs
+                 )
+             )
+         elif _scheduler == "LambdaLR":
+             lambdas = [lambda it: 1.0]
+
+             if "gamma" in lr_params and "milestone_steps" in lr_params:
+                 milestones = [
+                     Step.from_str(x, iterations_per_epoch).iteration
+                     for x in lr_params["milestone_steps"].split(",")
+                 ]
+                 lambdas.append(
+                     lambda it, milestones=milestones: lr_params["gamma"]
+                     ** bisect.bisect(milestones, it)
+                 )
+
+             # Add a linear learning rate warmup if specified
+             if "warmup_steps" in lr_params:
+                 warmup_iters = Step.from_str(
+                     lr_params["warmup_steps"], iterations_per_epoch
+                 ).iteration
+                 lambdas.append(
+                     lambda it, warmup_iters=warmup_iters: min(1.0, it / warmup_iters)
+                 )
+             returned_schedulers.append(
+                 retrieved_scheduler(
+                     optimizer,
+                     lambda it, lambdas=lambdas: np.prod([l(it) for l in lambdas]),
+                 )
+             )
+         elif _scheduler == "MultiStepLR":
+             milestones = [
+                 int(x.split("ep")[0]) for x in lr_params["milestone_steps"].split(",")
+             ]
+             returned_schedulers.append(
+                 retrieved_scheduler(
+                     optimizer, milestones=milestones, gamma=lr_params["gamma"]
+                 )
+             )
+         else:
+             returned_schedulers.append(retrieved_scheduler(optimizer, **lr_params))
+
+     if use_chained:
+         return optim.lr_scheduler.ChainedScheduler(returned_schedulers)
+
+     if use_sequential:
+         sequential_milestones = (
+             Config().trainer.lr_sequential_milestones
+             if hasattr(Config().trainer, "lr_sequential_milestones")
+             else 2
+         )
+         sequential_milestones = [
+             int(epoch) for epoch in sequential_milestones.split(",")
+         ]
+
+         return optim.lr_scheduler.SequentialLR(
+             optimizer, returned_schedulers, sequential_milestones
+         )
+     else:
+         return returned_schedulers[0]
+
+
+ class Step:
+     """Represents a particular step of training."""
+
+     def __init__(self, iteration: int, iterations_per_epoch: int) -> None:
+         if iteration < 0:
+             raise ValueError("iteration must >= 0.")
+         if iterations_per_epoch <= 0:
+             raise ValueError("iterations_per_epoch must be > 0.")
+         self._iteration = iteration
+         self.iterations_per_epoch = iterations_per_epoch
+
+     @staticmethod
+     def str_is_zero(s: str) -> bool:
+         return s in ["0ep", "0it", "0ep0it"]
+
+     @staticmethod
+     def from_iteration(iteration: int, iterations_per_epoch: int) -> "Step":
+         return Step(iteration, iterations_per_epoch)
+
+     @staticmethod
+     def from_epoch(epoch: int, iteration: int, iterations_per_epoch: int) -> "Step":
+         return Step(epoch * iterations_per_epoch + iteration, iterations_per_epoch)
+
+     @staticmethod
+     def from_str(s: str, iterations_per_epoch: int) -> "Step":
+         """Creates a step from a string that describes the number of epochs, iterations, or both.
+
+         Epochs: '120ep'
+         Iterations: '2000it'
+         Both: '120ep50it'"""
+
+         if "ep" in s and "it" in s:
+             ep = int(s.split("ep")[0])
+             it = int(s.split("ep")[1].split("it")[0])
+             if s != "{}ep{}it".format(ep, it):
+                 raise ValueError(f"Malformed string step: {s}")
+             return Step.from_epoch(ep, it, iterations_per_epoch)
+         elif "ep" in s:
+             ep = int(s.split("ep")[0])
+             if s != "{}ep".format(ep):
+                 raise ValueError(f"Malformed string step: {s}")
+             return Step.from_epoch(ep, 0, iterations_per_epoch)
+         elif "it" in s:
+             it = int(s.split("it")[0])
+             if s != "{}it".format(it):
+                 raise ValueError(f"Malformed string step: {s}")
+             return Step.from_iteration(it, iterations_per_epoch)
+         else:
+             raise ValueError(f"Malformed string step: {s}")
+
+     @staticmethod
+     def zero(iterations_per_epoch: int) -> "Step":
+         return Step(0, iterations_per_epoch)
+
+     @property
+     def iteration(self):
+         """The overall number of steps of training completed so far."""
+         return self._iteration
+
+     @property
+     def ep(self):
+         """The current epoch of training."""
+         return self._iteration // self.iterations_per_epoch
+
+     @property
+     def it(self):
+         """The iteration within the current epoch of training."""
+         return self._iteration % self.iterations_per_epoch
+
+     def _check(self, other):
+         if not isinstance(other, Step):
+             raise ValueError(f"Invalid type for other: {other}.")
+         if self.iterations_per_epoch != other.iterations_per_epoch:
+             raise ValueError(
+                 "Cannot compare steps when epochs are of different lengths."
+             )
+
+     def __lt__(self, other):
+         self._check(other)
+         return self._iteration < other.iteration
+
+     def __le__(self, other):
+         self._check(other)
+         return self._iteration <= other.iteration
+
+     def __eq__(self, other):
+         self._check(other)
+         return self._iteration == other.iteration
+
+     def __ne__(self, other):
+         self._check(other)
+         return self._iteration != other.iteration
+
+     def __gt__(self, other):
+         self._check(other)
+         return self._iteration > other.iteration
+
+     def __ge__(self, other):
+         self._check(other)
+         return self._iteration >= other.iteration
+
+     def __str__(self):
+         return f"(Iteration {self._iteration}; Iterations per Epoch: {self.iterations_per_epoch})"
plato/trainers/optimizers.py
@@ -0,0 +1,53 @@
+ """
+ Optimizers for training workloads.
+ """
+
+ from typing import Union
+
+ import torch_optimizer as torch_optim
+ from torch import optim
+ from timm import optim as timm_optim
+
+ from plato.config import Config
+
+
+ def get(model, **kwargs: Union[str, dict]) -> optim.Optimizer:
+     """Get an optimizer with its name and parameters obtained from the configuration file."""
+     registered_optimizers = {
+         "Adam": optim.Adam,
+         "Adadelta": optim.Adadelta,
+         "Adagrad": optim.Adagrad,
+         "AdaHessian": torch_optim.Adahessian,
+         "AdamW": optim.AdamW,
+         "SparseAdam": optim.SparseAdam,
+         "Adamax": optim.Adamax,
+         "ASGD": optim.ASGD,
+         "LBFGS": optim.LBFGS,
+         "NAdam": optim.NAdam,
+         "RAdam": optim.RAdam,
+         "RMSprop": optim.RMSprop,
+         "Rprop": optim.Rprop,
+         "SGD": optim.SGD,
+         "LARS": timm_optim.lars.Lars,
+     }
+
+     optimizer_name = (
+         kwargs["optimizer_name"]
+         if "optimizer_name" in kwargs
+         else Config().trainer.optimizer
+     )
+     optimizer_params = (
+         kwargs["optimizer_params"]
+         if "optimizer_params" in kwargs
+         else Config().parameters.optimizer._asdict()
+     )
+
+     # Ensure eps is a float
+     if "eps" in optimizer_params:
+         optimizer_params["eps"] = float(optimizer_params["eps"])
+
+     optimizer = registered_optimizers.get(optimizer_name)
+     if optimizer is not None:
+         return optimizer(model.parameters(), **optimizer_params)
+
+     raise ValueError(f"No such optimizer: {optimizer_name}")
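A minimal usage sketch of the optimizer factory, with hypothetical hyperparameter values; supplying optimizer_name and optimizer_params bypasses Config().trainer.optimizer and Config().parameters.optimizer:

    from torch import nn
    from plato.trainers import optimizers

    model = nn.Linear(10, 2)
    sgd = optimizers.get(
        model,
        optimizer_name="SGD",
        optimizer_params={"lr": 0.01, "momentum": 0.9, "weight_decay": 5e-4},
    )
    # sgd is an optim.SGD bound to model.parameters().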
plato/trainers/pascal_voc.py
@@ -0,0 +1,80 @@
+ """
+ A customized trainer for image segmentation on PASCAL VOC dataset (2012).
+ """
+
+ import torch.nn as nn
+ import torch
+ import numpy as np
+
+ from plato.trainers import basic
+
+
+ class Evaluator(object):
+     def __init__(self, num_class):
+         self.num_class = num_class
+         self.confusion_matrix = np.zeros((self.num_class,) * 2)
+
+     def Mean_Intersection_over_Union(self):
+         MIoU = np.diag(self.confusion_matrix) / (
+             np.sum(self.confusion_matrix, axis=1)
+             + np.sum(self.confusion_matrix, axis=0)
+             - np.diag(self.confusion_matrix)
+         )
+         MIoU = np.nanmean(MIoU)
+         return MIoU
+
+     def _generate_matrix(self, gt_image, pre_image):
+         mask = (gt_image >= 0) & (gt_image < self.num_class)
+         label = self.num_class * gt_image[mask].astype("int") + pre_image[mask]
+         count = np.bincount(label, minlength=self.num_class**2)
+         confusion_matrix = count.reshape(self.num_class, self.num_class)
+         return confusion_matrix
+
+     def add_batch(self, gt_image, pre_image):
+         assert gt_image.shape == pre_image.shape
+         self.confusion_matrix += self._generate_matrix(gt_image, pre_image)
+
+     def reset(self):
+         self.confusion_matrix = np.zeros((self.num_class,) * 2)
+
+
+ class Trainer(basic.Trainer):
+     """The federated learning trainer for the image segmentation on PASCAL VOC"""
+
+     def __init__(self, model=None, **kwargs):
+         """Initializing the trainer with the provided model.
+
+         Arguments:
+             model: The model to train.
+             client_id: The ID of the client using this trainer (optional).
+         """
+         super().__init__(model)
+
+         self.loss_criterion = nn.BCEWithLogitsLoss()
+         self.num_class = 20
+
+     def test_model(self, config, testset, sampler=None, **kwargs):
+         test_loader = torch.utils.data.DataLoader(
+             testset, batch_size=config["batch_size"], shuffle=False
+         )
+
+         total = 0
+         evaluator = Evaluator(self.num_class)
+         evaluator.reset()
+         with torch.no_grad():
+             for examples, labels in test_loader:
+                 examples, labels = examples.to(self.device), labels.to(self.device)
+
+                 outputs = self.model(examples)
+
+                 _, predicted = torch.max(outputs.data, 1)
+                 total += labels.size(0)
+                 labels = torch.squeeze(labels, 1).cpu().numpy()
+                 predicted = predicted.cpu().numpy()
+                 print("shape of pred: ", predicted.shape)
+                 print("shape of labels: ", labels.shape)
+                 evaluator.add_batch(labels, predicted)
+
+         accuracy = evaluator.Mean_Intersection_over_Union()
+
+         return accuracy
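The Evaluator accumulates a per-class confusion matrix across batches and reports mean IoU. A toy sketch of the computation (illustrative values, not from the package): two classes, four pixels, one of them misclassified:

    import numpy as np
    from plato.trainers.pascal_voc import Evaluator

    evaluator = Evaluator(num_class=2)
    gt = np.array([[0, 0], [1, 1]])    # ground-truth labels
    pred = np.array([[0, 1], [1, 1]])  # predictions; pixel (0, 1) is wrong
    evaluator.add_batch(gt, pred)

    # IoU(class 0) = 1/2, IoU(class 1) = 2/3, so MIoU = (1/2 + 2/3) / 2 ~ 0.583
    print(evaluator.Mean_Intersection_over_Union())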
plato/trainers/registry.py
@@ -0,0 +1,44 @@
+ """
+ Having a registry of all available classes is convenient for retrieving an instance
+ based on a configuration at run-time.
+ """
+
+ import logging
+
+ from plato.config import Config
+
+
+ from plato.trainers import basic, diff_privacy, pascal_voc, gan, split_learning
+
+ registered_trainers = {
+     "basic": basic.Trainer,
+     "timm_basic": basic.TrainerWithTimmScheduler,
+     "diff_privacy": diff_privacy.Trainer,
+     "pascal_voc": pascal_voc.Trainer,
+     "gan": gan.Trainer,
+     "split_learning": split_learning.Trainer,
+ }
+
+
+ def get(model=None, callbacks=None):
+     """Get the trainer with the provided name."""
+     trainer_name = Config().trainer.type
+     logging.info("Trainer: %s", trainer_name)
+
+     if Config().trainer.model_name == "yolov8":
+         from plato.trainers import yolov8
+
+         return yolov8.Trainer()
+     elif Config().trainer.type == "HuggingFace":
+         from plato.trainers import huggingface
+
+         return huggingface.Trainer(model=model, callbacks=callbacks)
+
+     elif Config().trainer.type == "self_supervised_learning":
+         from plato.trainers import self_supervised_learning
+
+         return self_supervised_learning.Trainer(model=model, callbacks=callbacks)
+     elif trainer_name in registered_trainers:
+         return registered_trainers[trainer_name](model=model, callbacks=callbacks)
+     else:
+         raise ValueError(f"No such trainer: {trainer_name}")
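A minimal usage sketch, assuming a Plato configuration file has already been loaded and its trainer.type field is set to "basic"; the registry reads Config() rather than taking the trainer name as an argument:

    from plato.trainers import registry as trainers_registry

    # Returns basic.Trainer (or another registered trainer, per the config);
    # model and callbacks are forwarded to the trainer's constructor.
    trainer = trainers_registry.get(model=None, callbacks=None)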